diff --git a/megatron/core/optimizer/cpu_offloading/optimizer_state_offloader.py b/megatron/core/optimizer/cpu_offloading/optimizer_state_offloader.py new file mode 100644 index 0000000000..81fd116c8b --- /dev/null +++ b/megatron/core/optimizer/cpu_offloading/optimizer_state_offloader.py @@ -0,0 +1,315 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Optimizer state offloading class.""" + +from typing import TYPE_CHECKING, Dict, List, Tuple + +import torch + +if TYPE_CHECKING: + from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer + + +class OptimizerStateOffloader: + """ + Manages offloading of optimizer states and master weights to CPU. + Used with DistributedOptimizer to reduce GPU memory usage. + + Supports overlapped D2H/H2D transfers using CUDA streams. + + Master weights can be stored in two locations: + - In adam optimizer state (when use_precision_aware_optimizer_no_fp8_or_ds_fp8 is True) + - In mcore's shard_fp32_from_float16_groups + """ + + OPTIMIZER_STATE_KEYS = ('exp_avg', 'exp_avg_sq') + MASTER_WEIGHT_KEY = 'master_param' + + def __init__(self, distrib_optimizer: "DistributedOptimizer"): + """ + Args: + distrib_optimizer: The DistributedOptimizer to offload states and master weights from. + """ + self.dist_optimizer = distrib_optimizer + self.adam_optimizer = distrib_optimizer.optimizer + + # Only support TE FusedAdam optimizer for now. + try: + from transformer_engine.pytorch.optimizers import FusedAdam + + assert isinstance(self.adam_optimizer, FusedAdam), ( + f"OptimizerStateOffloader requires TE FusedAdam optimizer, " + f"but got {type(self.adam_optimizer).__name__}" + ) + except ImportError: + raise ImportError( + "OptimizerStateOffloader requires transformer_engine.pytorch.optimizers.FusedAdam" + ) + + # Check if master weights are stored in adam optimizer state + self.optimizer_contains_master_weights = self.adam_optimizer.master_weights + + # CUDA streams for async transfers + self._d2h_stream = torch.cuda.Stream() + self._h2d_stream = torch.cuda.Stream() + + # CPU buffers for optimizer states: {param: {key: cpu_tensor}} + self._opt_state_cpu_buffers: Dict[torch.Tensor, Dict[str, torch.Tensor]] = {} + + # CPU buffers for mcore master weights, matching the structure of source groups + # List[List[cpu_tensor]] + self._shard_fp32_from_float16_cpu_buffers: List[List[torch.Tensor]] = [] + + # State tracking + self._offloaded = False + self._offloaded_state_keys: Tuple[str, ...] = () + self._offloaded_mcore_master_weights = False + + # Track whether optimizer states (exp_avg, exp_avg_sq) have been initialized. + # These are lazily initialized by FusedAdam during the first optimizer.step(). + # Master weights (shard_fp32_from_float16_groups) are available from the start. + self._optimizer_states_initialized = False + + def mark_optimizer_states_initialized(self): + """ + Mark that optimizer states (exp_avg, exp_avg_sq) are now available. + Should be called after the first optimizer.step() completes. + """ + self._optimizer_states_initialized = True + + def _get_state_keys_to_offload( + self, offload_optimizer_states: bool, offload_master_weights: bool + ) -> Tuple[str, ...]: + """Get the state keys in FusedAdam to offload based on configuration.""" + keys = [] + # Skip optimizer states offloading if they haven't been initialized yet. + # Optimizer states are lazily initialized by FusedAdam during the first optimizer.step(). 
+ if self._optimizer_states_initialized: + if offload_optimizer_states: + keys.extend(self.OPTIMIZER_STATE_KEYS) + if offload_master_weights and self.optimizer_contains_master_weights: + keys.append(self.MASTER_WEIGHT_KEY) + return tuple(keys) + + def _ensure_state_cpu_buffer( + self, param: torch.Tensor, state_key: str, gpu_tensor: torch.Tensor, pin_memory: bool = True + ) -> torch.Tensor: + """Get or create a CPU buffer for a state tensor.""" + if param not in self._opt_state_cpu_buffers: + self._opt_state_cpu_buffers[param] = {} + + if state_key not in self._opt_state_cpu_buffers[param]: + cpu_buffer = torch.empty( + gpu_tensor.size(), + dtype=gpu_tensor.dtype, + layout=gpu_tensor.layout, + device='cpu', + pin_memory=pin_memory, + ) + self._opt_state_cpu_buffers[param][state_key] = cpu_buffer + + return self._opt_state_cpu_buffers[param][state_key] + + def _offload_shard_groups( + self, + shard_groups: List[List[torch.Tensor]], + cpu_buffers: List[List[torch.Tensor]], + pin_memory: bool = True, + ): + """Offload a shard group to CPU buffers.""" + # Initialize CPU buffers on first call + if len(cpu_buffers) == 0: + for group in shard_groups: + group_buffers = [] + for gpu_tensor in group: + cpu_buffer = torch.empty( + gpu_tensor.size(), + dtype=gpu_tensor.dtype, + layout=gpu_tensor.layout, + device='cpu', + pin_memory=pin_memory, + ) + group_buffers.append(cpu_buffer) + cpu_buffers.append(group_buffers) + + # Copy D2H + for group_idx, group in enumerate(shard_groups): + for param_idx, gpu_tensor in enumerate(group): + cpu_buffer = cpu_buffers[group_idx][param_idx] + cpu_buffer.copy_(gpu_tensor, non_blocking=pin_memory) + gpu_tensor.record_stream(self._d2h_stream) + + def _offload_states( + self, + offload_optimizer_states: bool, + offload_master_weights: bool, + use_pin_memory: bool = True, + ): + """Offload optimizer states and/or master weights to CPU.""" + # Offload states from adam optimizer + self._offloaded_state_keys = self._get_state_keys_to_offload( + offload_optimizer_states, offload_master_weights + ) + states = self.adam_optimizer.state + + for param, param_state in states.items(): + for state_key in self._offloaded_state_keys: + if state_key not in param_state: + continue + + gpu_tensor = param_state[state_key] + if not isinstance(gpu_tensor, torch.Tensor) or not gpu_tensor.is_cuda: + continue + + cpu_buffer = self._ensure_state_cpu_buffer( + param, state_key, gpu_tensor, use_pin_memory + ) + cpu_buffer.copy_(gpu_tensor, non_blocking=use_pin_memory) + gpu_tensor.record_stream(self._d2h_stream) + + # Offload mcore master weights if not in optimizer state + if offload_master_weights and not self.optimizer_contains_master_weights: + self._offload_shard_groups( + self.dist_optimizer.shard_fp32_from_float16_groups, + self._shard_fp32_from_float16_cpu_buffers, + use_pin_memory, + ) + self._offloaded_mcore_master_weights = True + + def _release_states(self): + """Replace optimizer state GPU tensors with CPU tensors to free GPU memory.""" + states = self.adam_optimizer.state + + for param, param_state in states.items(): + if param not in self._opt_state_cpu_buffers: + continue + + for state_key in self._offloaded_state_keys: + if state_key not in self._opt_state_cpu_buffers[param]: + continue + + param_state[state_key].untyped_storage().resize_(0) + + if self._offloaded_mcore_master_weights: + for group in self.dist_optimizer.shard_fp32_from_float16_groups: + for gpu_tensor in group: + gpu_tensor.untyped_storage().resize_(0) + + def _reload_shard_groups( + self, + shard_groups: 
List[List[torch.Tensor]], + cpu_buffers: List[List[torch.Tensor]], + is_allocate_stage: bool, + ): + """Reload shard groups from CPU to GPU.""" + for group_idx, group in enumerate(shard_groups): + for param_idx, _ in enumerate(group): + cpu_buffer = cpu_buffers[group_idx][param_idx] + if is_allocate_stage: + shard_groups[group_idx][param_idx].untyped_storage().resize_( + cpu_buffer.untyped_storage().size() + ) + else: + shard_groups[group_idx][param_idx].copy_( + cpu_buffer, non_blocking=cpu_buffer.is_pinned() + ) + + def _reload_states(self, is_allocate_stage: bool): + """ + Reload optimizer states and/or master weights from CPU to GPU. + + If is_allocate_stage is True, only allocate GPU memory for the states and master weights, + but do not copy the data from CPU to GPU. Otherwise, copy the data from CPU to GPU. + The two processes are separated to make sure that the GPU memory is allocated on the + default stream to avoid fragmentation. + """ + # Reload states to adam optimizer + states = self.adam_optimizer.state + + for param, param_state in states.items(): + if param not in self._opt_state_cpu_buffers: + continue + + for state_key in self._offloaded_state_keys: + if state_key not in self._opt_state_cpu_buffers[param]: + continue + + cpu_buffer = self._opt_state_cpu_buffers[param][state_key] + if is_allocate_stage: + param_state[state_key].untyped_storage().resize_( + cpu_buffer.untyped_storage().size() + ) + else: + param_state[state_key].copy_(cpu_buffer, non_blocking=cpu_buffer.is_pinned()) + + # Reload mcore master weights if not in optimizer state + if self._offloaded_mcore_master_weights: + self._reload_shard_groups( + self.dist_optimizer.shard_fp32_from_float16_groups, + self._shard_fp32_from_float16_cpu_buffers, + is_allocate_stage, + ) + + def offload(self, offload_optimizer_states: bool = True, offload_master_weights: bool = True): + """ + Offload optimizer states and/or master weights to CPU. + Starts async D2H transfer that can overlap with other operations. + + Args: + offload_optimizer_states: Whether to offload exp_avg, exp_avg_sq. + offload_master_weights: Whether to offload master weights. + """ + if not offload_optimizer_states and not offload_master_weights: + return + + # Wait for current stream finishing updating the optimizer states. + self._d2h_stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self._d2h_stream): + self._offload_states(offload_optimizer_states, offload_master_weights) + + self._offloaded = True + + def release_gpu_memory(self): + """ + Release GPU memory for optimizer states and master weights after D2H copy completes. + + This is separated from offload() to allow delayed GPU memory release, + which is needed for mxfp8 + overlap_param_gather case where master weights + must remain on GPU until after _copy_main_params_to_param_buffer() is called. + """ + if not self._offloaded: + return + + self._release_states() + + def reload(self): + """ + Reload optimizer states and/or master weights from CPU to GPU. + Call before optimizer.step() to ensure states are on GPU. + """ + if not self._offloaded: + return + + # Allocate GPU memory on the current stream to avoid fragmentation. + self._reload_states(is_allocate_stage=True) + + self._h2d_stream.wait_stream(self._d2h_stream) + self._h2d_stream.wait_stream(torch.cuda.current_stream()) + + # Reload states on the h2d stream to overlap with other operations. 
+ with torch.cuda.stream(self._h2d_stream): + self._reload_states(is_allocate_stage=False) + + self._offloaded_state_keys = () + self._offloaded_mcore_master_weights = False + self._offloaded = False + + def sync_before_step(self): + """ + Wait for H2D reload to complete before optimizer.step(). + Must be called to ensure states are on GPU before optimizer uses them. + + This is separated from reload() to make it possible to move the reload ahead of time. + """ + torch.cuda.current_stream().wait_stream(self._h2d_stream) diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 6e093f96f7..9536bc4f9e 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -49,6 +49,7 @@ from ..fp8_utils import dequantize_fp8_tensor, is_float8tensor, quantize_param_shard from ..transformer.fsdp_dtensor_checkpoint import handle_experts_in_state_dict from ..transformer.module import MegatronModule +from .cpu_offloading.optimizer_state_offloader import OptimizerStateOffloader from .grad_scaler import MegatronGradScaler from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper, param_group_identifier_keys from .optimizer_config import OptimizerConfig @@ -604,6 +605,10 @@ def __init__( self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges] self.optimizer.load_state_dict(self.optimizer.state_dict()) + self._state_offloader: Optional[OptimizerStateOffloader] = None + if self.config.offload_optimizer_states: + self._state_offloader = OptimizerStateOffloader(self) + def _get_model_param_range_map(self, param: torch.nn.Parameter): """ Given a model param, get the index sub-range of the param that this @@ -2580,6 +2585,8 @@ def step_with_ready_grads(self) -> bool: Under the hood, either launch synchronous param all-gathers or get ready to launch asynchorous all-gathers that get overlapped with the next forward pass. """ + if self._state_offloader is not None: + self._state_offloader.sync_before_step() update_successful = super().step_with_ready_grads() timers = self.config.timers @@ -2600,4 +2607,22 @@ def step_with_ready_grads(self) -> bool: if timers is not None: timers('params-all-gather').stop() + if self._state_offloader is not None: + self._state_offloader.mark_optimizer_states_initialized() + return update_successful + + def offload_states(self): + """Offload states to CPU.""" + if self._state_offloader is not None: + self._state_offloader.offload() + + def reload_offloaded_states(self): + """Start async reload of offloaded states.""" + if self._state_offloader is not None: + self._state_offloader.reload() + + def release_offloaded_gpu_states(self): + """Release GPU memory after D2H completes. For delayed release case.""" + if self._state_offloader is not None: + self._state_offloader.release_gpu_memory() diff --git a/megatron/core/optimizer/optimizer_config.py b/megatron/core/optimizer/optimizer_config.py index 679878ed95..1813488d7b 100644 --- a/megatron/core/optimizer/optimizer_config.py +++ b/megatron/core/optimizer/optimizer_config.py @@ -266,6 +266,12 @@ class OptimizerConfig: pin_cpu_params: bool = True """If True, pin the optimizer parameters to CPU memory.""" + offload_optimizer_states: bool = False + """ + If True, offload optimizer states to CPU after each optimizer step and + reload them before the next optimizer step. 
+ """ + ################ # Miscellaneous ################ diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 9aba3a7cb8..0dec84a49e 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1271,6 +1271,11 @@ def validate_args(args, defaults={}): "must be used in conjunction with `--fp8-recipe delayed`." ) + if args.offload_optimizer_states: + assert args.use_distributed_optimizer, "offload_optimizer_states is only supported with distributed optimizer" + assert args.optimizer == 'adam', "offload_optimizer_states is only supported with adam optimizer" + assert not args.use_megatron_fsdp, "offload_optimizer_states does not support Megatron-FSDP for now." + if args.non_persistent_ckpt_type == "local": assert args.non_persistent_local_ckpt_dir is not None, "Tried to use local checkpointing without specifying --local-ckpt-dir!" if args.replication: @@ -2389,6 +2394,14 @@ def _add_training_args(parser): help='Disable pinning of CPU memory for gradients.') group.add_argument('--no-pin-cpu-params', action='store_false', dest='pin_cpu_params', help='Disable pinning of CPU memory for parameters.') + group.add_argument('--offload-optimizer-states', + action='store_true', + dest='offload_optimizer_states', + help='Offload optimizer states to CPU after each optimizer step and ' + 'reload them before the next optimizer step. ' + 'Only support TE FusedAdam optimizer.' + 'Note that this still uses pure GPU optimizer instead of ' + 'HybridDeviceOptimizer for --optimizer-cpu-offload.') group.add_argument('--dataloader-type', type=str, default=None, choices=['single', 'cyclic', 'external'], help='Single pass vs multiple pass data loader') diff --git a/megatron/training/training.py b/megatron/training/training.py index 845d271f62..8aff2556d1 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1425,6 +1425,12 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch rerun_state_machine = get_rerun_state_machine() while rerun_state_machine.should_run_forward_backward(data_iterator): + # Offload optimizer states to CPU if enabled. + if args.offload_optimizer_states: + for optim_instance in optimizer.chained_optimizers: + if isinstance(optim_instance, DistributedOptimizer): + optim_instance.offload_states() + # Set grad to zero. for model_chunk in model: model_chunk.zero_grad_buffer() @@ -1458,6 +1464,14 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch if isinstance(optim_instance, DistributedOptimizer): optim_instance._copy_main_params_to_param_buffer() + # Release GPU memory for offloaded optimizer states. + # This needs to be done after _copy_main_params_to_param_buffer(). + # Separate offload and release to allow early D2H transfer to overlap with other operations. + if args.offload_optimizer_states: + for optim_instance in optimizer.chained_optimizers: + if isinstance(optim_instance, DistributedOptimizer): + optim_instance.release_offloaded_gpu_states() + # Forward pass. losses_reduced = forward_backward_func( forward_step_func=forward_step_func, @@ -2305,7 +2319,21 @@ def train( config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] if len(model) == 1: config.param_sync_func = config.param_sync_func[0] - config.finalize_model_grads_func = finalize_model_grads + + # Wrap finalize_model_grads to reload offloaded optimizer states before grad finalization. + # This allows H2D transfer to overlap with grad all-reduce. 
+ if args.offload_optimizer_states: + + def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs): + # Reload offloaded states for all DistributedOptimizer instances + for optim_instance in optimizer.chained_optimizers: + if isinstance(optim_instance, DistributedOptimizer): + optim_instance.reload_offloaded_states() + return finalize_model_grads(*fmg_args, **fmg_kwargs) + + config.finalize_model_grads_func = finalize_model_grads_with_state_reload + else: + config.finalize_model_grads_func = finalize_model_grads if args.log_energy: energy_monitor.setup() diff --git a/tests/unit_tests/test_optimizer_state_offloading.py b/tests/unit_tests/test_optimizer_state_offloading.py new file mode 100644 index 0000000000..baaab35518 --- /dev/null +++ b/tests/unit_tests/test_optimizer_state_offloading.py @@ -0,0 +1,337 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for OptimizerStateOffloader.""" + +import pytest +import torch +import torch.nn as nn + +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer +from megatron.core.transformer import TransformerConfig +from tests.unit_tests.test_utilities import Utils + +try: + from transformer_engine.pytorch.optimizers import FusedAdam # noqa: F401 + + TE_FUSED_ADAM_AVAILABLE = True +except ImportError: + TE_FUSED_ADAM_AVAILABLE = False + + +class SimpleModel(nn.Module): + """Simple model for testing.""" + + def __init__(self, hidden_size=256): + super().__init__() + self.fc1 = nn.Linear(hidden_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + return self.fc2(torch.relu(self.fc1(x))) + + +def create_model_and_optimizer(hidden_size=256, offload_optimizer_states=True, **optimizer_kwargs): + """Helper to create model and optimizer for tests.""" + model = SimpleModel(hidden_size=hidden_size).bfloat16().cuda() + ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) + model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model + ) + + default_config = dict( + optimizer='adam', + bf16=True, + lr=0.001, + use_distributed_optimizer=True, + offload_optimizer_states=offload_optimizer_states, + ) + default_config.update(optimizer_kwargs) + + optimizer_config = OptimizerConfig(**default_config) + optim = get_megatron_optimizer(optimizer_config, [model]) + return model, optim + + +def run_forward_backward_step(model, optim, hidden_size=256): + """Run a single forward-backward-step cycle.""" + input_tensor = torch.randn(8, hidden_size, dtype=torch.bfloat16, device='cuda') + output = model(input_tensor) + output.sum().backward() + optim.step() + optim.zero_grad() + + +# ============================================================================= +# Test 1: Basic OptimizerStateOffloader Initialization +# ============================================================================= +@pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam") +def test_offloader_initialization(): + """Test that OptimizerStateOffloader initializes correctly.""" + Utils.initialize_model_parallel() + model, optim = create_model_and_optimizer() + dist_optim = optim.chained_optimizers[0] + + # Offloader is created in __init__ when offload_optimizer_states=True + assert dist_optim._state_offloader is not None + offloader = dist_optim._state_offloader + + # Verify offloader properties + assert 
offloader.adam_optimizer is not None + assert offloader._d2h_stream is not None + assert offloader._h2d_stream is not None + assert offloader._offloaded is False + + # Before first step, optimizer states are not initialized yet + assert offloader._optimizer_states_initialized is False + + # Run one step to initialize optimizer states + run_forward_backward_step(model, optim) + + # After first step, optimizer states should be marked as initialized + assert offloader._optimizer_states_initialized is True + Utils.destroy_model_parallel() + + +# ============================================================================= +# Test 2: Early Master Weight Offloading Before First Step +# ============================================================================= +@pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam") +def test_early_master_weight_offloading(): + """Test that master weights can be offloaded before the first optimizer step.""" + Utils.initialize_model_parallel() + model, optim = create_model_and_optimizer() + dist_optim = optim.chained_optimizers[0] + + # Offloader is created in __init__ + assert dist_optim._state_offloader is not None + offloader = dist_optim._state_offloader + + # Before first step, optimizer states are not initialized + assert offloader._optimizer_states_initialized is False + + # Capture original master weights before offload + original_master_weights = [] + for group in dist_optim.shard_fp32_from_float16_groups: + group_weights = [tensor.clone() for tensor in group] + original_master_weights.append(group_weights) + + # Offload before first step - should only offload master weights + offloader.offload() + offloader.release_gpu_memory() + torch.cuda.synchronize() + + # Verify master weights were offloaded (storage resized to 0) + for group in dist_optim.shard_fp32_from_float16_groups: + for tensor in group: + assert tensor.untyped_storage().size() == 0, "Master weight should be offloaded" + + # Reload master weights + offloader.reload() + offloader.sync_before_step() + + # Verify master weights match after reload + for group_idx, group in enumerate(dist_optim.shard_fp32_from_float16_groups): + for param_idx, tensor in enumerate(group): + original = original_master_weights[group_idx][param_idx] + torch.testing.assert_close( + tensor, + original, + msg=f"Master weight [{group_idx}][{param_idx}] mismatch after offload/reload", + ) + + # Now run a step and verify optimizer states can be offloaded after + run_forward_backward_step(model, optim) + assert offloader._optimizer_states_initialized is True + + Utils.destroy_model_parallel() + + +# ============================================================================= +# Test 3: Offload and Reload Correctness +# ============================================================================= +@pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam") +@pytest.mark.parametrize("offload_optimizer_states", [True, False]) +@pytest.mark.parametrize("offload_master_weights", [True, False]) +def test_offload_reload_correctness(offload_optimizer_states, offload_master_weights): + """Test that offload/reload preserves optimizer state values.""" + if not offload_optimizer_states and not offload_master_weights: + pytest.skip("At least one offload type required") + + Utils.initialize_model_parallel() + model, optim = create_model_and_optimizer() + dist_optim = optim.chained_optimizers[0] + + # Run steps to build up optimizer state + for _ in range(3): + 
run_forward_backward_step(model, optim) + + offloader = dist_optim._state_offloader + + # Capture original states before offload + original_states = {} + for param, state in offloader.adam_optimizer.state.items(): + original_states[param] = { + k: v.clone() for k, v in state.items() if isinstance(v, torch.Tensor) + } + + # Offload + offloader.offload( + offload_optimizer_states=offload_optimizer_states, + offload_master_weights=offload_master_weights, + ) + + # Release GPU memory + offloader.release_gpu_memory() + torch.cuda.synchronize() + + # Reload + offloader.reload() + offloader.sync_before_step() + + # Verify states match after reload + for param, state in offloader.adam_optimizer.state.items(): + if param in original_states: + for key, original_tensor in original_states[param].items(): + if key in state and isinstance(state[key], torch.Tensor): + reloaded_tensor = state[key] + assert reloaded_tensor.device.type == 'cuda', f"State {key} should be on GPU" + torch.testing.assert_close( + reloaded_tensor, + original_tensor, + msg=f"State {key} mismatch after offload/reload", + ) + Utils.destroy_model_parallel() + + +# ============================================================================= +# Test 4: GPU Memory Release Verification +# ============================================================================= +@pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam") +def test_gpu_memory_release(): + """Test that GPU memory is actually freed after release_gpu_memory().""" + Utils.initialize_model_parallel() + # Use larger model for measurable memory impact + model, optim = create_model_and_optimizer(hidden_size=1024) + dist_optim = optim.chained_optimizers[0] + + # Initialize optimizer states + run_forward_backward_step(model, optim, hidden_size=1024) + + offloader = dist_optim._state_offloader + + # Measure memory before offload + torch.cuda.synchronize() + torch.cuda.empty_cache() + memory_before = torch.cuda.memory_allocated() + + # Offload and release + offloader.offload() + offloader.release_gpu_memory() + + # Wait for async operations + torch.cuda.synchronize() + torch.cuda.empty_cache() + memory_after = torch.cuda.memory_allocated() + + # Memory should decrease + memory_freed = memory_before - memory_after + assert memory_freed > 0, f"Expected memory to be freed, but got {memory_freed} bytes difference" + Utils.destroy_model_parallel() + + +# ============================================================================= +# Test 5: Multiple Offload/Reload Cycles +# ============================================================================= +@pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam") +def test_multiple_offload_reload_cycles(): + """Test that multiple offload/reload cycles work correctly.""" + Utils.initialize_model_parallel() + model, optim = create_model_and_optimizer() + dist_optim = optim.chained_optimizers[0] + + # Initialize + run_forward_backward_step(model, optim) + + offloader = dist_optim._state_offloader + + # Run multiple cycles + for cycle in range(5): + # Offload + offloader.offload() + offloader.release_gpu_memory() + + # Reload + offloader.reload() + offloader.sync_before_step() + + # Run optimizer step + run_forward_backward_step(model, optim) + + # Verify model can still produce valid outputs + input_tensor = torch.randn(8, 256, dtype=torch.bfloat16, device='cuda') + output = model(input_tensor) + assert not output.isnan().any(), "Model output contains NaN after multiple cycles" + 
Utils.destroy_model_parallel() + + +# ============================================================================= +# Test 6: Training Correctness with Offloading +# ============================================================================= +@pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam") +def test_training_correctness_with_offloading(): + """Test that training with offloading produces same results as without.""" + Utils.initialize_model_parallel() + torch.manual_seed(42) + + # Model 1: with offloading + model1, optim1 = create_model_and_optimizer(offload_optimizer_states=True, lr=0.01) + + # Model 2: without offloading (reference) + torch.manual_seed(42) + model2, optim2 = create_model_and_optimizer(offload_optimizer_states=False, lr=0.01) + + # Train both models + n_steps = 10 + torch.manual_seed(123) + dist_optim1 = optim1.chained_optimizers[0] + + # Offloader is created in __init__ when offload_optimizer_states=True + assert dist_optim1._state_offloader is not None + offloader = dist_optim1._state_offloader + + for step in range(n_steps): + input_tensor = torch.randn(8, 256, dtype=torch.bfloat16, device='cuda') + + # Model 1 with offloading + # Offload states (master weights can be offloaded from the start, + # optimizer states will be skipped until after first step) + offloader.offload() + offloader.release_gpu_memory() + + output1 = model1(input_tensor) + loss1 = output1.sum() + loss1.backward() + + offloader.reload() + offloader.sync_before_step() + optim1.step() + optim1.zero_grad() + + # Model 2 without offloading + output2 = model2(input_tensor) + loss2 = output2.sum() + loss2.backward() + optim2.step() + optim2.zero_grad() + + # Compare final model weights + for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()): + torch.testing.assert_close( + p1.data, + p2.data, + atol=1e-5, + rtol=1e-4, + msg=f"Parameter {n1} mismatch between offloaded and non-offloaded training", + ) + Utils.destroy_model_parallel()
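
Note on the offload/reload protocol: the four offloader entry points (offload, release_gpu_memory, reload, sync_before_step) implement a specific CUDA stream-ordering pattern that the patch spreads across several methods and call sites. The sketch below is not part of the patch; it reduces that pattern to a single tensor so the ordering is easier to follow. The tensor, buffer, and stream names are illustrative only; the method names in the comments refer to the offloader code above. It assumes a CUDA-enabled PyTorch build.

# Illustrative sketch only (not part of the patch): the stream-ordering protocol
# behind offload() / release_gpu_memory() / reload() / sync_before_step(),
# reduced to a single tensor. Names here are made up for the example.
import torch

assert torch.cuda.is_available()

d2h_stream = torch.cuda.Stream()
h2d_stream = torch.cuda.Stream()

# Stand-in for one optimizer state shard (e.g. exp_avg) living on the GPU.
state = torch.randn(1024, 1024, device='cuda')
reference = state.clone()
cpu_buf = torch.empty(state.size(), dtype=state.dtype, device='cpu', pin_memory=True)

# offload(): order the D2H copy after pending work, then run it on a side stream.
d2h_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(d2h_stream):
    cpu_buf.copy_(state, non_blocking=True)
    # Tell the caching allocator the block is still in use on d2h_stream, so
    # freeing the storage below cannot hand the memory out before the copy retires.
    state.record_stream(d2h_stream)

# release_gpu_memory(): drop the GPU storage to reclaim memory.
state.untyped_storage().resize_(0)

# reload(): re-allocate on the default stream, then issue the H2D copy on
# another stream so it can overlap with other work (e.g. gradient reduction).
state.untyped_storage().resize_(cpu_buf.untyped_storage().size())
h2d_stream.wait_stream(d2h_stream)                   # CPU buffer must be fully written
h2d_stream.wait_stream(torch.cuda.current_stream())  # allocation must be ordered first
with torch.cuda.stream(h2d_stream):
    state.copy_(cpu_buf, non_blocking=True)

# sync_before_step(): the consumer (the fused Adam step) waits on the H2D stream.
torch.cuda.current_stream().wait_stream(h2d_stream)
torch.testing.assert_close(state, reference)

Two design choices from the patch are reflected here: GPU storage is re-allocated on the default stream before the H2D copy is issued on the side stream, so that allocations stay on one stream and fragmentation is avoided, and the release step is kept separate from the D2H copy so that, in the mxfp8 + overlap_param_gather case, master weights can remain on the GPU until _copy_main_params_to_param_buffer() has run.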