diff --git a/training/checkpoint_utils.py b/training/checkpoint_utils.py
index dacf131..d86561c 100644
--- a/training/checkpoint_utils.py
+++ b/training/checkpoint_utils.py
@@ -45,7 +45,8 @@ def save_checkpoint(
         amp,
         global_step,
         list_head_names,
-        keep_num=5):
+        keep_num=5,
+        optimizer=None):
     """
     Save training state, create a separate folder for each step, and subfolders for each PFC head
 
@@ -58,12 +59,13 @@
         global_step: Current global step count
         list_head_names: List of head names
         keep_num: Number of checkpoints to keep
+        optimizer: Optimizer object (AdamW) - CRITICAL for proper resume
     """
     # Create folder for current step
     step_dir = os.path.join(output_dir, f"{global_step:08d}")
     os.makedirs(step_dir, exist_ok=True)
 
-    # Save backbone model and optimizer state (only on rank 0)
+    # Save backbone model and scheduler state (only on rank 0)
     if rank == 0:
         # Save backbone model (move to CPU)
         backbone_path = os.path.join(step_dir, "backbone.pt")
@@ -81,6 +83,21 @@
         logging.info(f"Backbone, scheduler saved at step {global_step}")
 
+    # Save optimizer state per-rank (CRITICAL for proper resume)
+    # Optimizer contains per-rank PFC parameter moments, so each rank saves its own
+    if optimizer is not None:
+        optimizer_path = os.path.join(step_dir, f"optimizer_{rank:03d}.pt")
+        # Move optimizer state tensors to CPU for saving
+        opt_state_dict = optimizer.state_dict()
+        cpu_state_dict = {
+            'state': {k: {sk: sv.cpu() if isinstance(sv, torch.Tensor) else sv
+                          for sk, sv in v.items()}
+                      for k, v in opt_state_dict['state'].items()},
+            'param_groups': opt_state_dict['param_groups']
+        }
+        torch.save(cpu_state_dict, optimizer_path)
+        logging.info(f"Rank {rank}: Optimizer state saved at step {global_step}")
+
     if isinstance(pfc_modules, list):
         # Each rank saves its own PFC module
         for head_id, (head_name, pfc) in enumerate(zip(list_head_names, pfc_modules)):
@@ -163,7 +180,7 @@ def clean_old_checkpoints(output_dir, keep_num=5):
 
 
 def load_checkpoint(output_dir, step, backbone, pfc_modules, lr_scheduler,
-                    amp, list_head_names):
+                    amp, list_head_names, optimizer=None):
     """
     Load training state from checkpoint folder at specified step
 
@@ -175,6 +192,7 @@
         lr_scheduler: Learning rate scheduler
         amp: Automatic mixed precision object
         list_head_names: List of head names
+        optimizer: Optimizer object (AdamW) - CRITICAL for proper resume
 
     Returns:
         dict: Contains restored global step information
@@ -282,6 +300,16 @@
     else:
         logging.warning(f"AMP state file not found: {amp_file}")
 
+    # Load optimizer state per-rank (CRITICAL for proper resume)
+    if optimizer is not None:
+        optimizer_file = os.path.join(step_dir, f"optimizer_{rank:03d}.pt")
+        if os.path.exists(optimizer_file):
+            optimizer.load_state_dict(torch.load(optimizer_file))
+            logging.info(f"Rank {rank}: Loaded optimizer state from step {step}")
+        else:
+            logging.warning(f"Rank {rank}: Optimizer state file not found: {optimizer_file}")
+            logging.warning("Training will resume with fresh optimizer moments - expect temporary loss spike")
+
     return {
         'global_step': step
     }
diff --git a/training/train.py b/training/train.py
index d3139f7..b7823d8 100644
--- a/training/train.py
+++ b/training/train.py
@@ -362,6 +362,7 @@ def _expand(name, v):
             lr_scheduler,
             None,
             args.list_head_names,
+            optimizer=opt,  # Pass optimizer for proper resume with AdamW moments
         )
         if result is not None:
             global_step = result["global_step"]
@@ -691,6 +692,7 @@ def wrap_ddp(model):
                 global_step=global_step,
                 list_head_names=args.list_head_names,
                 keep_num=20,
+                optimizer=opt,  # Save optimizer state for proper resume
             )
             # Also save in HuggingFace format
             save_hf_checkpoint(args.output, backbone, global_step=global_step, image_size=args.image_size[0])
@@ -705,6 +707,7 @@ def wrap_ddp(model):
                 global_step=global_step,
                 list_head_names=args.list_head_names,
                 keep_num=20,
+                optimizer=opt,  # Save optimizer state for proper resume
             )
             # Also save final model in HuggingFace format
             save_hf_checkpoint(args.output, backbone, global_step=global_step, image_size=args.image_size[0])