From 8428937a39e636b8d4647f0acf5d045cdc5ecfbe Mon Sep 17 00:00:00 2001 From: ruit Date: Sat, 13 Dec 2025 22:02:23 -0800 Subject: [PATCH 1/6] can run grpo workflow, but reward abnormal Signed-off-by: ruit --- examples/configs/grpo_math_1B.yaml | 12 ++ nemo_rl/algorithms/grpo.py | 88 +++++++++++-- nemo_rl/models/generation/interfaces.py | 4 +- nemo_rl/models/generation/lora.py | 120 ++++++++++++++++++ .../models/generation/vllm/vllm_backend.py | 74 ++++++++++- .../models/generation/vllm/vllm_generation.py | 18 ++- nemo_rl/models/generation/vllm/vllm_worker.py | 49 ++++++- nemo_rl/models/policy/lm_policy.py | 20 ++- .../policy/workers/dtensor_policy_worker.py | 5 + .../workers/dtensor_policy_worker_v2.py | 44 ++++++- .../policy/workers/megatron_policy_worker.py | 6 +- .../models/generation/test_vllm_generation.py | 112 +++++++++++++++- .../models/policy/test_dtensor_worker_v2.py | 26 ++++ 13 files changed, 548 insertions(+), 30 deletions(-) create mode 100644 nemo_rl/models/generation/lora.py diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 1dd9639472..46982537ad 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -90,6 +90,18 @@ policy: tensor_parallel_size: 1 context_parallel_size: 1 custom_parallel_plan: null + # LoRA (Low-Rank Adaptation) Configuration + lora_cfg: + enabled: True # Set to True to enable LoRA fine-tuning + target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers) + exclude_modules: [] # List of module names to exclude from LoRA + match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules) + dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64 + alpha: 32 # LoRA scaling factor: effective learning rate multiplier = alpha/dim. Typical values: 16, 32, 64 + dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout) + dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA) + lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform" + use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). Disable when tensor_parallel_size > 1 megatron_cfg: enabled: false diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 8ab62d00fb..f6fd03265e 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -80,6 +80,16 @@ # =============================================================================== TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) +# ANSI color codes +CYAN = "\033[96m" +GREEN = "\033[92m" +YELLOW = "\033[93m" +BLUE = "\033[94m" +MAGENTA = "\033[95m" +RED = "\033[91m" +BOLD = "\033[1m" +RESET = "\033[0m" + class RewardScalingConfig(TypedDict): """Configure linear reward scaling with clamping. 
@@ -460,6 +470,12 @@ def setup( ) policy_config["megatron_cfg"]["train_iters"] = total_train_iters + if policy_config.get("dtensor_cfg", {}).get("enabled", False): + lora_cfg = policy_config.get("dtensor_cfg", {}).get("lora_cfg", {}) + if lora_cfg.get("enabled", False): + # Override the vLLM lora config with the DTensor lora config + generation_config["vllm_cfg"]["lora_cfg"] = lora_cfg + # Define initialization functions that will be used in all paths def init_policy(): """Initialize policy training workers.""" @@ -908,6 +924,8 @@ def refit_policy_generation( _refit_buffer_size_gb: Optional[int] = None, timer: Optional[Timer] = None, kv_scales: Optional[dict[str, float]] = None, + refit_base_model_weights: Optional[bool] = True, + refit_lora_weights: Optional[bool] = False, ) -> None: """Refit the policy generation interface with the latest policy weights. @@ -920,18 +938,13 @@ def refit_policy_generation( timer: Optional Timer used to time the prepare/transfer/update phase kv_scales: Optional dictionary of KV cache scales for FP8 quantization. """ - if colocated_inference: - policy.offload_before_refit() - policy_generation.prepare_for_generation(tags=["weights"]) - - # Create a context manager that does nothing when timer is None - timer_context = ( - timer.time("prepare_for_generation/transfer_and_update_weights") - if timer is not None - else nullcontext() + assert refit_base_model_weights or refit_lora_weights, ( + "refit_base_model_weights and refit_lora_weights cannot be both False" ) - with timer_context: - # update weights + + def _perform_refit_weights( + refit_base_model_weights: bool, refit_lora_weights: bool + ): update_success = False if colocated_inference: # get model param keys, which is grouped by size @@ -946,9 +959,15 @@ def refit_policy_generation( ) futures_train = policy.stream_weights_via_ipc_zmq( - buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales + buffer_size_bytes=buffer_size_bytes, + kv_scales=kv_scales, + refit_base_model_weights=refit_base_model_weights, + refit_lora_weights=refit_lora_weights, + ) + futures_inference = policy_generation.update_weights_via_ipc_zmq( + refit_base_model_weights=refit_base_model_weights, + refit_lora_weights=refit_lora_weights, ) - futures_inference = policy_generation.update_weights_via_ipc_zmq() # wait for all futures to complete ray.get(futures_train) results = ray.get(futures_inference) @@ -971,6 +990,31 @@ def refit_policy_generation( "a problem within the generation backend (e.g., vLLM worker).\n" ) raise RuntimeError(error_message) + return update_success + + if colocated_inference: + policy.offload_before_refit() + policy_generation.prepare_for_generation(tags=["weights"]) + + # Create a context manager that does nothing when timer is None + timer_context = ( + timer.time("prepare_for_generation/transfer_and_update_weights") + if timer is not None + else nullcontext() + ) + with timer_context: + update_success = False + if refit_base_model_weights: + update_success = _perform_refit_weights( + refit_base_model_weights=True, refit_lora_weights=False + ) + if refit_lora_weights: + update_success = ( + _perform_refit_weights( + refit_base_model_weights=False, refit_lora_weights=True + ) + and update_success + ) if colocated_inference: policy.offload_after_refit() @@ -1013,6 +1057,8 @@ def grpo_train( policy_generation = policy # type: ignore NEED_REFIT = False POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running + REFIT_BASE_MODEL_WEIGHTS = True + REFIT_LORA_WEIGHTS = policy.lora_enabled 
assert policy_generation is not None # for mypy type check # Check if we need to sync KV cache scales @@ -1044,8 +1090,16 @@ def grpo_train( if val_at_start and current_step == 0: print("\nšŸ” Running initial validation...", flush=True) if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation(policy, policy_generation, colocated_inference) + refit_policy_generation( + policy, + policy_generation, + colocated_inference, + refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, + refit_lora_weights=REFIT_LORA_WEIGHTS, + ) POLICY_GENERATION_STALE = False + # Disable base model weights refit after first refit if enable lora weights refit + REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True else: policy_generation.prepare_for_generation() val_metrics, validation_timings = validate( @@ -1139,8 +1193,11 @@ def grpo_train( colocated_inference, timer=timer, kv_scales=kv_scales_cache if sync_kv_scales else None, + refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, + refit_lora_weights=REFIT_LORA_WEIGHTS, ) POLICY_GENERATION_STALE = False + REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True else: if colocated_inference: policy.offload_after_refit() # unload optimizer to make space for generation @@ -1377,8 +1434,11 @@ def grpo_train( policy_generation, colocated_inference, kv_scales=kv_scales_cache if sync_kv_scales else None, + refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, + refit_lora_weights=REFIT_LORA_WEIGHTS, ) POLICY_GENERATION_STALE = False + REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True else: if colocated_inference: policy.offload_after_refit() # unload optimizer to make space for generation diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index d134027bdf..31f9536b02 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -245,7 +245,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" raise NotImplementedError - def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: + def update_weights_via_ipc_zmq( + self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False + ) -> list[ray.ObjectRef]: """Update the model weights from the given IPC handles.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/lora.py b/nemo_rl/models/generation/lora.py new file mode 100644 index 0000000000..ae058857ad --- /dev/null +++ b/nemo_rl/models/generation/lora.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Any, Optional + +from vllm.lora.request import LoRARequest + + +class LoRARequestWithCfgAndWeights(LoRARequest): + lora_cfg: Optional[dict] = None + lora_weights: Optional[dict[str, Any]] = None + + +def patched_load_adapter(self, lora_request: LoRARequestWithCfgAndWeights): + try: + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping + expected_lora_lst: list[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_lst.extend(packed_modules_mapping[module]) + else: + expected_lora_lst.append(module) + if module == "experts": + expected_lora_lst.append(module) + expected_lora_modules = set(expected_lora_lst) + lora_weights = None + + from vllm.lora.peft_helper import PEFTHelper + + if isinstance(lora_request, LoRARequestWithCfgAndWeights): + lora_cfg = lora_request.lora_cfg + lora_weights = lora_request.lora_weights + peft_helper = PEFTHelper.from_dict(lora_cfg) + else: + raise ValueError(f"Unsupported LoRA request type: {type(lora_request)}") + + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. + model = self._adapter_manager.model + hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None) + print(f"hf_to_vllm_mapper in lora.patched_load_adapter: {hf_to_vllm_mapper}") + if isinstance(lora_request, LoRARequestWithCfgAndWeights): + lora = self._lora_model_cls.from_lora_tensors( + lora_model_id=lora_request.lora_int_id, + tensors=lora_weights, + peft_helper=peft_helper, + device="cpu", + dtype=self.lora_config.lora_dtype, + embeddings=None, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + + else: + raise ValueError(f"Unsupported LoRA request type: {type(lora_request)}") + + except FileNotFoundError as e: + # FileNotFoundError should be raised if both + # - No adapter found to download from huggingface (or in + # offline mode) + # - No local adapter files found at `lora_request.lora_path` + # For NotFoundError + raise ValueError( + f"Loading lora {lora_request.lora_name} failed: No adapter " + f"found for {lora_request.lora_path}" + ) from e + except Exception as e: + # For BadRequestError + raise e + + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}." 
+            )
+        return lora
+
+
+def apply_lora_patches():
+    # func_path = "vllm.lora.worker_manager.LRUCacheWorkerLoRAManager.load_adapter"
+    # patcher = patch(func_path, patched_load_adapter)
+    from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+
+    setattr(LRUCacheWorkerLoRAManager, "_load_adapter", patched_load_adapter)
+
+
+lora_int_id = 0
+
+
+# Note: not sure whether this belongs here or in nemo_rl/models/generation/vllm/utils.py
+def get_vllm_lora_metadata() -> dict[str, Any]:
+    global lora_int_id
+    lora_int_id += 1  # Can be any unique id except 0
+    lora_name = f"{lora_int_id}"
+    lora_path = "dummy_lora_path"
+    return {
+        "lora_name": lora_name,
+        "lora_int_id": lora_int_id,
+        "lora_path": lora_path,
+    }
diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py
index 1e947ed444..6e06372f00 100644
--- a/nemo_rl/models/generation/vllm/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm/vllm_backend.py
@@ -125,8 +125,27 @@ def _maybe_process_fp8_kv_cache(self) -> None:
                 target_device,
             )
 
+    def _apply_weight_name_mapping(
+        self, weights: list[tuple[str, torch.Tensor]]
+    ) -> list[tuple[str, torch.Tensor]]:
+        """Apply weight name mapping if LoRA is enabled."""
+        new_weights = []
+        for name, w in weights:
+            new_name = name
+            if ".self_attn." in name and name.endswith("_proj.weight"):
+                new_name = name.replace("_proj.weight", "_proj.base_layer.weight")
+            if ".mlp." in name and name.endswith("_proj.weight"):
+                new_name = name.replace("_proj.weight", "_proj.base_layer.weight")
+            new_weights.append((new_name, w))
+        return new_weights
+
     @wrap_with_nvtx_name("vllm_internal_worker_extension/update_weights_via_ipc_zmq")
-    def update_weights_via_ipc_zmq(self) -> bool:
+    def update_weights_via_ipc_zmq(
+        self,
+        lora_config: dict[str, Any] = {},
+        refit_base_model_weights: bool = False,
+        refit_lora_weights: bool = True,
+    ) -> bool:
         """Receive and update model weights via ZMQ IPC socket.
         Returns:
@@ -183,7 +202,52 @@ def update_weights_via_ipc_zmq(self) -> bool:
                     # the fp8 load_weights additionally casts bf16 weights into fp8
                     fp8.load_weights(weights, self.model_runner)
                 else:
-                    self.model_runner.model.load_weights(weights=weights)
+                    if refit_base_model_weights:
+                        # Apply weight name mapping if LoRA is enabled
+                        if (
+                            lora_config
+                            and "enabled" in lora_config
+                            and lora_config["enabled"]
+                        ):
+                            weights = self._apply_weight_name_mapping(weights)
+                        self.model_runner.model.load_weights(weights=weights)
+                    elif refit_lora_weights:
+                        assert lora_config, (
+                            "lora_config is not provided; cannot refit LoRA weights"
+                        )
+                        from nemo_rl.models.generation.lora import (
+                            LoRARequestWithCfgAndWeights,
+                            get_vllm_lora_metadata,
+                        )
+
+                        # Convert the LoRA config into a PEFT-style dict for PEFTHelper
+                        # LoRAConfig(max_lora_rank=8, max_loras=1, fully_sharded_loras=False, max_cpu_loras=1, lora_dtype=torch.bfloat16, lora_extra_vocab_size=256, default_mm_loras=None, bias_enabled=False)
+
+                        lora_cfg_dict = dict(
+                            {
+                                "r": lora_config["dim"],
+                                "lora_alpha": lora_config["alpha"],
+                                "target_modules": lora_config["target_modules"],
+                            }
+                        )
+                        lora_metadata = get_vllm_lora_metadata()
+                        # Remove any previously loaded LoRA before adding the new one (max_loras is set to 1)
+                        self.remove_lora(lora_id=lora_metadata["lora_int_id"])
+                        lora_request = LoRARequestWithCfgAndWeights(
+                            **lora_metadata,
+                            lora_cfg=lora_cfg_dict,
+                            lora_weights=dict(
+                                {
+                                    name_weight[0]: name_weight[1]
+                                    for name_weight in weights
+                                }
+                            ),
+                        )
+                        self.add_lora(lora_request=lora_request)
+                    else:
+                        raise ValueError(
+                            "refit_base_model_weights and refit_lora_weights cannot both be False"
+                        )
 
                 torch.cuda.current_stream().synchronize()
 
@@ -273,3 +337,9 @@ def start_gpu_profiling(self) -> None:
     def stop_gpu_profiling(self) -> None:
         """Stop GPU profiling."""
         torch.cuda.profiler.stop()
+
+    def get_lora_counts(self) -> int:
+        """Get the number of LoRA layers from the vLLM engine."""
+        results = self.list_loras()
+        print(f"Results: {results}")
+        return results
diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py
index 93540ebe82..80db9aa75d 100644
--- a/nemo_rl/models/generation/vllm/vllm_generation.py
+++ b/nemo_rl/models/generation/vllm/vllm_generation.py
@@ -767,7 +767,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
         # Wait for all futures to complete
         ray.get(futures)
 
-    def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
+    def update_weights_via_ipc_zmq(
+        self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False
+    ) -> list[ray.ObjectRef]:
         """Update weights of the policy using IPC handles via ZMQ socket."""
         if not self.worker_group or not self.worker_group.workers:
             raise RuntimeError("Worker group is not initialized")
@@ -783,6 +785,8 @@ def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
         futures = self.worker_group.run_all_workers_single_data(
             method_name,
             run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
+            refit_base_model_weights=refit_base_model_weights,
+            refit_lora_weights=refit_lora_weights,
         )
 
         # this function should co-work with lm_policy, so we should wait for all futures to complete outside
@@ -916,3 +920,15 @@ def requires_kv_scale_sync(self) -> bool:
         return "kv_cache_dtype" in self.cfg["vllm_cfg"] and self.cfg["vllm_cfg"][
             "kv_cache_dtype"
         ].startswith("fp8")
+
+    def get_lora_layers(self) -> list[dict[str, Any]]:
+        """Get the LoRA layers from the vLLM engine."""
+        futures = 
self.worker_group.run_all_workers_single_data("get_lora_layers") + results = ray.get(futures) + return results + + def get_lora_counts(self): + """Get the number of LoRA from the vLLM engine.""" + futures = self.worker_group.run_all_workers_single_data("get_lora_counts") + results = ray.get(futures) + return results diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 75e3334d4a..1a46e87760 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -139,6 +139,7 @@ def __init__( self.enable_expert_parallel = self.expert_parallel_size > 1 self.gpu_memory_utilization = self.cfg["vllm_cfg"]["gpu_memory_utilization"] self.precision = self.cfg["vllm_cfg"]["precision"] + self.lora_cfg = self.cfg["vllm_cfg"].get("lora_cfg", None) self.fraction_of_gpus = fraction_of_gpus self.is_model_owner = bundle_indices is not None @@ -388,6 +389,14 @@ def _patch_vllm_vit_flash_attn_backend(): ) # disable quantization vllm_kwargs["hf_overrides"]["quantization_config"] = {} + # Lora is enabled, add it to the vllm kwargs + if self.lora_cfg is not None and self.lora_cfg["enabled"]: + from nemo_rl.models.generation.lora import apply_lora_patches + + apply_lora_patches() + vllm_kwargs["enable_lora"] = True + vllm_kwargs["max_loras"] = 1 # only support one lora adapter + vllm_kwargs["max_lora_rank"] = self.lora_cfg["dim"] llm_kwargs = dict( model=self.model_name, @@ -719,7 +728,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_via_ipc_zmq") - def update_weights_via_ipc_zmq(self) -> bool: + def update_weights_via_ipc_zmq( + self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False + ) -> bool: """Update weights from IPC handles via ZMQ socket.""" try: assert self.llm is not None, ( @@ -733,7 +744,7 @@ def update_weights_via_ipc_zmq(self) -> bool: result_or_coro = self.llm.collective_rpc( "update_weights_via_ipc_zmq", - args=tuple(), + args=(self.lora_cfg, refit_base_model_weights, refit_lora_weights), ) worker_result = result_or_coro[0] @@ -781,6 +792,40 @@ def update_weights_from_collective(self) -> bool: traceback.print_exc() return False + def get_lora_layers(self) -> list[dict[str, Any]]: + """Get the LoRA layers from the vLLM engine.""" + + def _get_lora_layers(self): + model = self.get_model() + if model is None: + return [] + + from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA + + details = [] + for name, module in model.named_modules(): + if isinstance(module, BaseLinearLayerWithLoRA): + a_shapes = [tuple(t.shape) for t in module.lora_a_stacked] + b_shapes = [tuple(t.shape) for t in module.lora_b_stacked] + a_weights = [t for t in module.lora_a_stacked] + b_weights = [t for t in module.lora_b_stacked] + details.append( + { + "name": name, # layer name + "a_weights": a_weights, + "b_weights": b_weights, + } + ) + return details + + results = self.llm.collective_rpc(_get_lora_layers) + return results + + def get_lora_counts(self) -> int: + """Get the number of LoRA layers from the vLLM engine.""" + results = self.llm.collective_rpc("get_lora_counts") + return results + def reset_prefix_cache(self): """Reset the prefix cache of vLLM engine.""" assert self.llm is not None, ( diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 144683c95c..19ca457e86 100644 --- 
a/nemo_rl/models/policy/lm_policy.py
+++ b/nemo_rl/models/policy/lm_policy.py
@@ -86,6 +86,9 @@ def __init__(
                 "Configure either Megatron (policy.megatron_cfg.enabled=true) or "
                 "DTensor (policy.dtensor_cfg.enabled=true), not both."
             )
+        # Default to False; overridden below when LoRA is enabled
+        self.lora_enabled = False
+
         if megatron_enable:
             worker_builder_cls = "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker"
             tp_size = config["megatron_cfg"]["tensor_model_parallel_size"]
@@ -109,6 +112,8 @@ def __init__(
 
             # Check if _v2 is enabled in dtensor_cfg (defaults to False for backward compatibility)
             use_v2 = config.get("dtensor_cfg", {}).get("_v2", False)
+            lora_cfg = config.get("dtensor_cfg", {}).get("lora_cfg", {})
+            self.lora_enabled = lora_cfg.get("enabled", False)
 
             if use_v2:
                 worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
@@ -118,10 +123,9 @@ def __init__(
                     "if you are running a custom container or baremetal, you may need to set this variable manually. Example: export TORCH_CUDA_ARCH_LIST='9.0 10.0'"
                 )
             else:
-                assert (
-                    config["dtensor_cfg"].get("lora_cfg", {}).get("enabled", False)
-                    is False
-                ), "LoRA is not supported for DTensorPolicyWorker V1"
+                assert not self.lora_enabled, (
+                    "LoRA is not supported for DTensorPolicyWorker V1"
+                )
                 worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker"
 
             tp_size = config["dtensor_cfg"]["tensor_parallel_size"]
@@ -758,13 +762,19 @@ def get_free_memory_bytes(self) -> int:
         return free_memory_bytes
 
     def stream_weights_via_ipc_zmq(
-        self, buffer_size_bytes: int, kv_scales: Optional[dict[str, float]] = None
+        self,
+        buffer_size_bytes: int,
+        kv_scales: Optional[dict[str, float]] = None,
+        refit_base_model_weights: bool = True,
+        refit_lora_weights: bool = True,
     ) -> list[ray.ObjectRef]:
         """Send the weights for IPC handles via ZMQ socket."""
         futures = self.worker_group.run_all_workers_single_data(
             "stream_weights_via_ipc_zmq",
             buffer_size_bytes=buffer_size_bytes,
             kv_scales=kv_scales,
+            refit_base_model_weights=refit_base_model_weights,
+            refit_lora_weights=refit_lora_weights,
         )
 
         return futures
diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py
index 2903307c8b..ccd8d6053e 100644
--- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py
@@ -1680,8 +1680,13 @@ def stream_weights_via_ipc_zmq(
         self,
         buffer_size_bytes: int = 0,
         kv_scales: Optional[dict[str, float]] = None,
+        refit_base_model_weights: bool = True,
+        refit_lora_weights: bool = False,
     ) -> None:
         """Stream model weights to peer process via ZMQ IPC socket."""
+        assert refit_base_model_weights and not refit_lora_weights, (
+            "DTensor worker v1 does not support LoRA; refit_lora_weights must be False"
+        )
         if kv_scales is not None:
             raise NotImplementedError(
                 "FP8 kvcache is not currently supported for DTensor path, we will support it in the future."
diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
index 738146a7e2..02b3d45801 100644
--- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
+++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
@@ -1689,6 +1689,18 @@ def return_model_config(self) -> dict[str, Any]:
         """
         return self.model.config
 
+    def _is_lora_weight(self, name: str) -> bool:
+        """Check if the weight name belongs to a LoRA adapter."""
+        return (
+            name.endswith(".lora_A.weight")
+            or name.endswith(".lora_B.weight")
+            or name.endswith(".lora_scaling.weight")
+        )
+
+    def _is_base_model_weight(self, name: str) -> bool:
+        """Check if the weight name belongs to the base model."""
+        return not self._is_lora_weight(name)
+
     @torch.no_grad()
     def prepare_refit_info(self) -> Optional[dict[str, Any]]:
         """Prepare state dict metadata for weight refitting and IPC streaming."""
@@ -1719,6 +1731,8 @@ def stream_weights_via_ipc_zmq(
         self,
         buffer_size_bytes: int = 0,
         kv_scales: Optional[dict[str, float]] = None,
+        refit_base_model_weights: bool = True,
+        refit_lora_weights: bool = False,
     ) -> None:
         """Stream model weights to peer process via ZMQ IPC socket."""
         if kv_scales is not None:
@@ -1734,8 +1748,18 @@ def stream_weights_via_ipc_zmq(
         from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl
 
         def dtensor_params_generator():
-            """Generator that yields (name, tensor) pairs, converting DTensors to local tensors."""
+            """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.
+
+            Weights are filtered according to refit_base_model_weights and refit_lora_weights.
+            """
             for name, tensor in self.model.state_dict().items():
+                # Skip base model weights when refit_base_model_weights is False
+                if self._is_base_model_weight(name) and not refit_base_model_weights:
+                    continue
+
+                if self._is_lora_weight(name) and not refit_lora_weights:
+                    continue
+
                 if isinstance(tensor, DTensor):
                     # Convert DTensor to full tensor for streaming
                     full_tensor = tensor.full_tensor()
@@ -1759,7 +1783,10 @@ def dtensor_params_generator():
 
     @torch.no_grad()
     def broadcast_weights_for_collective(
-        self, kv_scales: Optional[dict[str, float]] = None
+        self,
+        kv_scales: Optional[dict[str, float]] = None,
+        refit_base_model_weights: bool = True,
+        refit_lora_weights: bool = False,
     ) -> None:
         """Broadcast the weights for collective communication."""
         if kv_scales is not None:
@@ -1784,8 +1811,19 @@ def _dtensor_post_iter_func(tensor, dtype):
         # param_iterator will return (name, tensor), we only need tensor
         dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype)
 
+        # Filter the state dict according to the refit flags before broadcasting
+        def _filtered_state_dict_iterator():
+            """Yield state dict entries filtered by refit_base_model_weights and refit_lora_weights."""
+            for name, tensor in self.model.state_dict().items():
+                # Skip base model weights when refit_base_model_weights is False
+                if self._is_base_model_weight(name) and not refit_base_model_weights:
+                    continue
+                if self._is_lora_weight(name) and not refit_lora_weights:
+                    continue
+                yield (name, tensor)
+
         packed_broadcast_producer(
-            iterator=iter(self.model.state_dict().items()),
+            iterator=_filtered_state_dict_iterator(),
             group=self.model_update_group,
             src=0,
             post_iter_func=dtensor_post_iter_func,
diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py
index 66767822e6..caece09081 100644
--- 
a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -2118,7 +2118,11 @@ def _iter_params_with_optional_kv_scales( @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") def stream_weights_via_ipc_zmq( - self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None + self, + buffer_size_bytes: int = 0, + kv_scales: Optional[dict[str, float]] = None, + refit_base_model_weights: bool = True, + refit_lora_weights: bool = False, ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index e39fef12d4..1287b1641c 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -35,7 +35,7 @@ from nemo_rl.models.generation.vllm.vllm_worker_async import ( _replace_prefix_tokens, ) -from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy import LoRAConfig, PolicyConfig from nemo_rl.models.policy.lm_policy import Policy model_name = "Qwen/Qwen3-0.6B" @@ -127,6 +127,19 @@ "generation": deepcopy(basic_vllm_test_config), } +basic_lora_test_config: LoRAConfig = { + "enabled": True, + "target_modules": [], + "exclude_modules": [], + "match_all_linear": True, + "dim": 8, + "alpha": 32, + "dropout": 0.0, + "dropout_position": "post", + "lora_A_init": "xavier", + "use_triton": True, +} + def get_basic_megatron_test_config( tp: int = 1, @@ -2518,3 +2531,100 @@ def test_vllm_megatron_weight_update_with_packing(cluster, test_input_data): megatron_policy.shutdown() if vllm_generation: vllm_generation.shutdown() + + +# ANSI color codes +CYAN = "\033[96m" +GREEN = "\033[92m" +YELLOW = "\033[93m" +BLUE = "\033[94m" +MAGENTA = "\033[95m" +RED = "\033[91m" +BOLD = "\033[1m" +RESET = "\033[0m" + + +def test_vllm_lora_refit_sync_colocated(cluster, tokenizer): + """Test vLLM LoRA refit with sync engine and colocated setup.""" + vllm_config = deepcopy(basic_vllm_test_config) + vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = True + vllm_config["vllm_cfg"]["async_engine"] = False + vllm_config = configure_generation_config(vllm_config, tokenizer) + + dtensor_config = deepcopy(basic_dtensor_test_config) + dtensor_config["dtensor_cfg"]["_v2"] = True + dtensor_config["dtensor_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + dtensor_config["dtensor_cfg"]["lora_cfg"]["enabled"] = True + + print(f"\n{CYAN}{BOLD}{'=' * 80}\n>>> CREATING DTENSOR POLICY\n{'=' * 80}{RESET}") + lm_policy = Policy(cluster, dtensor_config, tokenizer) + + print(f"\n{CYAN}{BOLD}{'=' * 80}\n>>> CREATING VLLM POLICY\n{'=' * 80}{RESET}") + vllm_policy = VllmGeneration(cluster, vllm_config) + vllm_policy.finish_generation() + + print(f"\n{YELLOW}{BOLD}{'=' * 80}\n>>> PREPARING REFIT INFO\n{'=' * 80}{RESET}") + state_dict_info = lm_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + # take it outside statistics to get clean peak memory during refit + lm_policy.offload_before_refit() + + print( + f"\n{YELLOW}{BOLD}{'=' * 80}\n>>> STARTING VLLM POLICY REFIT BASE MODEL WEIGHTS\n{'=' * 80}{RESET}" + ) + refit_policy_generation( + lm_policy, + vllm_policy, + vllm_config["colocated"]["enabled"], + _refit_buffer_size_gb=1.5, + refit_base_model_weights=True, + refit_lora_weights=True, + ) + + 
print(f"\n{YELLOW}{BOLD}{'=' * 80}\n>>> GETTING LORA LAYERS\n{'=' * 80}{RESET}") + lora_layers = vllm_policy.get_lora_layers()[0][0] + for layer in lora_layers: + for a_weight in layer["a_weights"]: + assert torch.all(a_weight == 1) + for b_weight in layer["b_weights"]: + assert torch.all(b_weight == 0) + + +def test_vllm_lora_generation(cluster, tokenizer): + """Test vLLM LoRA refit with sync engine and colocated setup.""" + vllm_config = deepcopy(basic_vllm_test_config) + vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = True + vllm_config["vllm_cfg"]["async_engine"] = False + vllm_config = configure_generation_config(vllm_config, tokenizer) + + print(f"\n{CYAN}{BOLD}{'=' * 80}\n>>> CREATING VLLM POLICY\n{'=' * 80}{RESET}") + vllm_policy = VllmGeneration(cluster, vllm_config) + vllm_policy.prepare_for_generation() + + print(f"\n{CYAN}{BOLD}{'=' * 80}\n>>> GENERATING TEXT\n{'=' * 80}{RESET}") + prompts = [ + "What is the largest number, all of whose digits are 1 or 4, and whose digits add up to 12?" + ] + test_tokenizer = get_tokenizer({"name": model_name}) + tokenized = test_tokenizer( + prompts, + padding=True, + truncation=True, + max_length=256, + return_tensors="pt", + padding_side="right", + ) + test_input_data = BatchedDataDict( + { + "input_ids": tokenized["input_ids"], + "input_lengths": tokenized["attention_mask"].sum(dim=1).to(torch.int32), + } + ) + outputs = vllm_policy.generate(test_input_data, greedy=True) + output_ids = outputs["output_ids"] + generated_texts = test_tokenizer.batch_decode(output_ids, skip_special_tokens=True) + print( + f"\n{CYAN}{BOLD}{'=' * 80}\n>>> GENERATED TEXT:\n{generated_texts}\n{'=' * 80}{RESET}" + ) diff --git a/tests/unit/models/policy/test_dtensor_worker_v2.py b/tests/unit/models/policy/test_dtensor_worker_v2.py index 4e9f33f99e..f7e42f7eee 100644 --- a/tests/unit/models/policy/test_dtensor_worker_v2.py +++ b/tests/unit/models/policy/test_dtensor_worker_v2.py @@ -456,3 +456,29 @@ def test_dtensor_v2_mixed_precision_training_and_logprobs( assert worker_info is not None, "Should get worker info" finally: policy.shutdown() + + +def test_dtensor_worker_v2_generate_lora_weights(two_gpu_virtual_cluster): + """Test that dtensor worker v2 can generate with LoRA weights.""" + config = create_test_config( + model_name="Qwen/Qwen3-0.6B", + dtensor_v2=True, + ) + lora_config = { + "enabled": True, + "target_modules": [], + "exclude_modules": [], + "match_all_linear": True, + "dim": 8, + "alpha": 32, + "dropout": 0.0, + "dropout_position": "post", + "lora_A_init": "xavier", + } + config["dtensor_cfg"]["lora_cfg"] = lora_config + lm_policy = Policy( + cluster=two_gpu_virtual_cluster, + config=config, + tokenizer=get_tokenizer(config["tokenizer"]), + ) + lm_policy.stream_weights_via_ipc_zmq(buffer_size_bytes=1024**3, kv_scales=None) From 6113b3f57261b56a710a69603941e955e826e43e Mon Sep 17 00:00:00 2001 From: ruit Date: Tue, 23 Dec 2025 19:50:18 -0800 Subject: [PATCH 2/6] update paraname map from vllm lora manager Signed-off-by: ruit --- nemo_rl/algorithms/grpo.py | 190 ++++++++++++++++++ nemo_rl/models/generation/lora.py | 34 +++- .../models/generation/vllm/vllm_backend.py | 48 ++++- .../models/generation/vllm/vllm_generation.py | 6 +- nemo_rl/models/generation/vllm/vllm_worker.py | 32 ++- nemo_rl/models/policy/lm_policy.py | 3 + nemo_rl/utils/logger.py | 18 ++ .../models/generation/test_vllm_generation.py | 115 +++++++---- 8 files changed, 374 insertions(+), 72 deletions(-) diff --git 
a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index f6fd03265e..de244256c8 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1026,6 +1026,156 @@ def _perform_refit_weights( # =============================================================================== +def dump_lora_layers_metadata( + lora_layers_weights, + *, + dump_path: Optional[str] = None, +) -> Optional[str]: + """Write LoRA layer metadata (layer name, A/B weight shapes and dtypes) to a JSON file. + + Path priority: + 1) explicit dump_path + 2) environment variable NRL_LORA_LAYERS_DUMP + 3) environment variable NRL_OUTPUT_DIR or current working directory + timestamped filename + Returns the written path on success; None on failure (and prints a warning). + """ + try: + import json + import os + import time + + if dump_path is None: + dump_path = os.environ.get("NRL_LORA_LAYERS_DUMP") + if dump_path is None: + default_dir = os.environ.get("NRL_OUTPUT_DIR") or os.getcwd() + dump_path = os.path.join( + default_dir, f"lora_layers_{int(time.time())}.json" + ) + + # Flatten potential DP-nested structure + flattened = [] + if ( + isinstance(lora_layers_weights, list) + and len(lora_layers_weights) > 0 + and isinstance(lora_layers_weights[0], list) + ): + for sub in lora_layers_weights: + if sub: + flattened.extend(sub) + else: + flattened = lora_layers_weights + + def _shapes_and_dtypes(tensors): + result = [] + for t in tensors or []: + try: + shape = tuple(int(x) for x in getattr(t, "shape", ())) + dtype = str(getattr(t, "dtype", "unknown")) + result.append({"shape": shape, "dtype": dtype}) + except Exception: + result.append({"shape": None, "dtype": "unknown"}) + return result + + sanitized = [] + for item in flattened or []: + if not isinstance(item, dict): + continue + sanitized.append( + { + "name": item.get("name"), + "a_shapes": _shapes_and_dtypes(item.get("a_weights")), + "b_shapes": _shapes_and_dtypes(item.get("b_weights")), + } + ) + + os.makedirs(os.path.dirname(dump_path) or ".", exist_ok=True) + with open(dump_path, "w") as f: + json.dump({"layers": sanitized}, f, indent=2) + print(f"[INFO] LoRA layer metadata written to {dump_path}") + return dump_path + except Exception as e: + print(f"[WARN] Failed to dump LoRA layer metadata: {e}") + return None + + +def dump_lora_layers_tensors( + lora_layers_weights, + *, + dump_path: Optional[str] = None, + cast_dtype: Optional[str] = None, +) -> Optional[str]: + """Write full LoRA layer weights to a file (torch.save). + + - Supports single dict or list[list[dict]] / list[dict] inputs. + - Moves tensors to CPU by default to avoid CUDA deserialization constraints. + - Optional cast_dtype: "float32" | "bfloat16" | "float16" + Returns the written path on success; None on failure. 
+ """ + try: + import os + import time + + import torch + + if dump_path is None: + dump_path = os.environ.get("NRL_LORA_LAYERS_TENSORS") + if dump_path is None: + default_dir = os.environ.get("NRL_OUTPUT_DIR") or os.getcwd() + dump_path = os.path.join(default_dir, f"lora_layers_{int(time.time())}.pt") + + # Normalize to a flat list[dict] + if isinstance(lora_layers_weights, dict): + flattened = [lora_layers_weights] + elif ( + isinstance(lora_layers_weights, list) + and len(lora_layers_weights) > 0 + and isinstance(lora_layers_weights[0], list) + ): + flattened = [] + for sub in lora_layers_weights: + if sub: + flattened.extend(sub) + else: + flattened = lora_layers_weights + + dtype_map = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + } + target_dtype = dtype_map.get(cast_dtype.lower(), None) if cast_dtype else None + + def _to_cpu_and_cast(tensor): + if not hasattr(tensor, "detach"): + return tensor + t = tensor.squeeze(0).squeeze(0).detach().to("cpu") + if target_dtype is not None: + t = t.to(target_dtype) + return t + + sanitized = [] + for item in flattened or []: + if not isinstance(item, dict): + continue + a_weights = [_to_cpu_and_cast(t) for t in (item.get("a_weights") or [])] + b_weights = [_to_cpu_and_cast(t) for t in (item.get("b_weights") or [])] + sanitized.append( + { + "name": item.get("name"), + "a_weights": a_weights, + "b_weights": b_weights, + } + ) + + os.makedirs(os.path.dirname(dump_path) or ".", exist_ok=True) + torch.save({"layers": sanitized}, dump_path) + print(f"[INFO] LoRA layer tensors written to {dump_path}") + return dump_path + except Exception as e: + print(f"[WARN] Failed to dump LoRA layer tensors: {e}") + return None + + def grpo_train( policy: ColocatablePolicyInterface, policy_generation: Optional[GenerationInterface], @@ -1203,6 +1353,39 @@ def grpo_train( policy.offload_after_refit() # unload optimizer to make space for generation policy_generation.prepare_for_generation() + # lora_layers_weights = policy_generation.get_lora_layers()[0][0] + # try: + # dump_lora_layers_metadata(lora_layers_weights, dump_path="/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/ruit/RL/logs/lora_layers_metadata.json") + # print(f"[INFO] LoRA layer metadata written to /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/ruit/RL/logs/lora_layers_metadata.json") + # dump_lora_layers_tensors(lora_layers_weights, dump_path="/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/ruit/RL/logs/lora_layers_tensors.pt") + # print(f"[INFO] LoRA layer tensors written to /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/ruit/RL/logs/lora_layers_tensors.pt") + # except Exception as e: + # print(f"[WARN] Failed to dump LoRA layer metadata: {e}") + + # print(f"[INFO] Initializing checkpoint...") + # checkpoint_path = checkpointer.init_tmp_checkpoint( + # total_steps + 1, grpo_save_state, master_config + # ) + # policy.save_checkpoint( + # weights_path=os.path.join( + # checkpoint_path, "policy", "weights" + # ), + # optimizer_path=os.path.join( + # checkpoint_path, "policy", "optimizer" + # ), + # tokenizer_path=os.path.join( + # checkpoint_path, "policy", "tokenizer" + # ), + # checkpointing_cfg=master_config["checkpointing"], + # ) + # torch.save( + # dataloader.state_dict(), + # os.path.join(checkpoint_path, "train_dataloader.pt"), + # ) + # checkpointer.finalize_checkpoint(checkpoint_path) + # print(f"[INFO] Checkpoint saved to {checkpoint_path}") + # exit() + 
dynamic_sampling_num_gen_batches += 1 with timer.time("generation"): # Clear vLLM logger metrics for each generation step @@ -1342,8 +1525,14 @@ def grpo_train( loss_multiplier[truncated] = 0 repeated_batch["loss_multiplier"] = loss_multiplier # Add loss mask and advantages to each message in LLMMessageLogType + # print_colored(f"length of message_log: {len(repeated_batch['message_log'])}", BLUE) + # print_colored(f"repeated_batch['message_log']: {repeated_batch['message_log']}", BLUE) for i, message_log in enumerate(repeated_batch["message_log"]): + # print_colored(f"length of message_log {i}: {len(message_log)}", BLUE) + # print_colored(f"message_log {i}: {message_log}", BLUE) for j, message in enumerate(message_log): + # print_colored(f"length of message {i}, {j}: {len(message)}", BLUE) + # print_colored(f"message {i}, {j}: {message}", BLUE) if message["role"] == "assistant": message["token_loss_mask"] = torch.ones_like( message["token_ids"] @@ -1353,6 +1542,7 @@ def grpo_train( message["token_ids"] ) if "generation_logprobs" not in message: + # print_colored(f"Set Generation logprobs to zeros for message {i}, {j}") message["generation_logprobs"] = torch.zeros_like( message["token_ids"], dtype=torch.float32 ) diff --git a/nemo_rl/models/generation/lora.py b/nemo_rl/models/generation/lora.py index ae058857ad..0f6b1f655c 100644 --- a/nemo_rl/models/generation/lora.py +++ b/nemo_rl/models/generation/lora.py @@ -15,7 +15,9 @@ from typing import Any, Optional +from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path class LoRARequestWithCfgAndWeights(LoRARequest): @@ -38,14 +40,18 @@ def patched_load_adapter(self, lora_request: LoRARequestWithCfgAndWeights): expected_lora_modules = set(expected_lora_lst) lora_weights = None - from vllm.lora.peft_helper import PEFTHelper - if isinstance(lora_request, LoRARequestWithCfgAndWeights): lora_cfg = lora_request.lora_cfg lora_weights = lora_request.lora_weights peft_helper = PEFTHelper.from_dict(lora_cfg) else: - raise ValueError(f"Unsupported LoRA request type: {type(lora_request)}") + lora_path = get_adapter_absolute_path(lora_request.lora_path) + + peft_helper = PEFTHelper.from_local_dir( + lora_path, + self.max_position_embeddings, + lora_request.tensorizer_config_dict, + ) # Validates the LoRA configuration against requirements before # loading weights, throwing an exception if validation fails. @@ -55,7 +61,6 @@ def patched_load_adapter(self, lora_request: LoRARequestWithCfgAndWeights): # to ensure correct loading of lora weights. 
model = self._adapter_manager.model hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None) - print(f"hf_to_vllm_mapper in lora.patched_load_adapter: {hf_to_vllm_mapper}") if isinstance(lora_request, LoRARequestWithCfgAndWeights): lora = self._lora_model_cls.from_lora_tensors( lora_model_id=lora_request.lora_int_id, @@ -70,9 +75,20 @@ def patched_load_adapter(self, lora_request: LoRARequestWithCfgAndWeights): embedding_padding_modules=self.embedding_padding_modules, weights_mapper=hf_to_vllm_mapper, ) - else: - raise ValueError(f"Unsupported LoRA request type: {type(lora_request)}") + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) except FileNotFoundError as e: # FileNotFoundError should be raised if both @@ -104,13 +120,9 @@ def apply_lora_patches(): setattr(LRUCacheWorkerLoRAManager, "_load_adapter", patched_load_adapter) -lora_int_id = 0 - - # Note: Not sure put it here or in nemo_rl/models/generation/vllm/utils.py def get_vllm_lora_metadata() -> dict[str, Any]: - global lora_int_id - lora_int_id += 1 # Can be any unique id exclude 0 + lora_int_id = 1 # Can be any unique id exclude 0 lora_name = f"{lora_int_id}" lora_path = "dummy_lora_path" return { diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 6e06372f00..7cefa3e512 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -23,6 +23,7 @@ calculate_aligned_size, rebuild_cuda_tensor_from_ipc, ) +from nemo_rl.utils.logger import RED, print_colored from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_consumer @@ -95,6 +96,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: e.g. {tensor_name: (shape, dtype)} """ self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + # self.vllm_state_dict_keys = self.model_runner.model.state_dict().keys() def _maybe_process_fp8_kv_cache(self) -> None: """Process weights after loading for FP8 KV cache (static scales).""" @@ -125,17 +127,41 @@ def _maybe_process_fp8_kv_cache(self) -> None: target_device, ) + def map_param_name(self, param_name: str) -> str: + lora_mgr = self.model_runner.model.lora_manager + supported_modules = lora_mgr.supported_lora_modules + packed_modules_mapping = lora_mgr.packed_modules_mapping + + parts = param_name.split(".") + if len(parts) < 2: + return param_name + + base_name = ".".join(parts[:-2]) # prefix + module_name = parts[-2] # e.g. q_proj/k_proj/v_proj/gate_proj/up_proj/... 
+ field_name = parts[-1] # weight/bias + + resolved_module_name = module_name + for packed_name, member_names in packed_modules_mapping.items(): + if module_name in member_names: + resolved_module_name = packed_name + break + + # use resolved_module_name for checking, but return the original module_name + if resolved_module_name in supported_modules: + if base_name != "": + return f"{base_name}.{module_name}.base_layer.{field_name}" + else: + return f"{module_name}.base_layer.{field_name}" + + return param_name + def _apply_weight_name_mapping( self, weights: list[tuple[str, torch.Tensor]] ) -> list[tuple[str, torch.Tensor]]: """Apply weight name mapping if LoRA is enabled.""" new_weights = [] for name, w in weights: - new_name = name - if ".self_attn." in name and name.endswith("_proj.weight"): - new_name = name.replace("_proj.weight", "_proj.base_layer.weight") - if ".mlp." in name and name.endswith("_proj.weight"): - new_name = name.replace("_proj.weight", "_proj.base_layer.weight") + new_name = self.map_param_name(name) new_weights.append((new_name, w)) return new_weights @@ -154,6 +180,10 @@ def update_weights_via_ipc_zmq( buffer = None weights = None + print_colored( + f"lora_config in update_weights_via_ipc_zmq: {self.model_runner.vllm_config.lora_config}", + RED, + ) try: self.maybe_init_zmq() while True: @@ -211,6 +241,7 @@ def update_weights_via_ipc_zmq( ): weights = self._apply_weight_name_mapping(weights) self.model_runner.model.load_weights(weights=weights) + print_colored("updated base model weights", RED) elif refit_lora_weights: assert lora_config, ( "lora_config is not provided, can not refit lora weights" @@ -244,6 +275,7 @@ def update_weights_via_ipc_zmq( ), ) self.add_lora(lora_request=lora_request) + print_colored("updated lora weights", RED) else: raise ValueError( "refit_base_model_weights and refit_lora_weights cannot be both False" @@ -337,9 +369,3 @@ def start_gpu_profiling(self) -> None: def stop_gpu_profiling(self) -> None: """Stop GPU profiling.""" torch.cuda.profiler.stop() - - def get_lora_counts(self) -> int: - """Get the number of LoRA layers from the vLLM engine.""" - results = self.list_loras() - print(f"Results: {results}") - return results diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 80db9aa75d..751c02db4b 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -927,8 +927,8 @@ def get_lora_layers(self) -> list[dict[str, Any]]: results = ray.get(futures) return results - def get_lora_counts(self): - """Get the number of LoRA from the vLLM engine.""" - futures = self.worker_group.run_all_workers_single_data("get_lora_counts") + def get_model_state_dict(self) -> dict[str, Any]: + """Get the model state dict from the vLLM engine.""" + futures = self.worker_group.run_all_workers_single_data("get_model_state_dict") results = ray.get(futures) return results diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 1a46e87760..7ac3071b48 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -34,6 +34,7 @@ from nemo_rl.models.generation.vllm.utils import format_prompt_for_vllm_generation from nemo_rl.models.huggingface.common import ModelFlag from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled +from nemo_rl.utils.logger import print_colored from nemo_rl.utils.nsys import wrap_with_nvtx_name @@ 
-390,10 +391,12 @@ def _patch_vllm_vit_flash_attn_backend(): # disable quantization vllm_kwargs["hf_overrides"]["quantization_config"] = {} # Lora is enabled, add it to the vllm kwargs + self.lora_enabled = False if self.lora_cfg is not None and self.lora_cfg["enabled"]: from nemo_rl.models.generation.lora import apply_lora_patches apply_lora_patches() + self.lora_enabled = True vllm_kwargs["enable_lora"] = True vllm_kwargs["max_loras"] = 1 # only support one lora adapter vllm_kwargs["max_lora_rank"] = self.lora_cfg["dim"] @@ -569,7 +572,20 @@ def generate( assert self.llm is not None, ( "Attempting to generate with either an uninitialized vLLM or non-model-owner" ) - outputs = self.llm.generate(prompts, sampling_params) + + lora_req = None + if self.lora_enabled: + # print_colored(f"list_lora in generate: {self.llm.llm_engine.list_loras()}") + from vllm.lora.request import LoRARequest + + from nemo_rl.models.generation.lora import get_vllm_lora_metadata + + lora_metadata = get_vllm_lora_metadata() + lora_req = LoRARequest( + **lora_metadata, + ) + print_colored(f"lora_req in generate: {lora_req}") + outputs = self.llm.generate(prompts, sampling_params, lora_request=lora_req) # Process the outputs - but preserve the original input padding structure output_ids_list = [] @@ -804,6 +820,7 @@ def _get_lora_layers(self): details = [] for name, module in model.named_modules(): + print_colored(f"name: {name}") if isinstance(module, BaseLinearLayerWithLoRA): a_shapes = [tuple(t.shape) for t in module.lora_a_stacked] b_shapes = [tuple(t.shape) for t in module.lora_b_stacked] @@ -821,9 +838,16 @@ def _get_lora_layers(self): results = self.llm.collective_rpc(_get_lora_layers) return results - def get_lora_counts(self) -> int: - """Get the number of LoRA layers from the vLLM engine.""" - results = self.llm.collective_rpc("get_lora_counts") + def get_model_state_dict(self) -> dict[str, torch.Tensor]: + """Get the model state dict from the vLLM engine.""" + + def _get_model_state_dict(self): + model = self.get_model() + if model is None: + return {} + return model.state_dict() + + results = self.llm.collective_rpc(_get_model_state_dict) return results def reset_prefix_cache(self): diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 19ca457e86..521b9945a9 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -114,6 +114,9 @@ def __init__( use_v2 = config.get("dtensor_cfg", {}).get("_v2", False) lora_cfg = config.get("dtensor_cfg", {}).get("lora_cfg", {}) self.lora_enabled = lora_cfg.get("enabled", False) + from nemo_rl.utils.logger import print_colored + + print_colored(f"LORA ENABLED in lm_policy: {self.lora_enabled}") if use_v2: worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index ed431986dc..b43a62df92 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -1460,3 +1460,21 @@ def get_next_experiment_dir(base_log_dir: str) -> str: os.makedirs(new_log_dir, exist_ok=True) return new_log_dir + + +# ANSI color codes +CYAN = "\033[96m" +GREEN = "\033[92m" +YELLOW = "\033[93m" +BLUE = "\033[94m" +MAGENTA = "\033[95m" +RED = "\033[91m" +BOLD = "\033[1m" +RESET = "\033[0m" + + +def print_colored(text: str, color: str = YELLOW): + line = "=" * 80 + print(f"\n{color}{BOLD}{line}{RESET}") + print(f"{color}{BOLD}{text}{RESET}") + print(f"{color}{BOLD}{line}{RESET}\n") diff --git 
a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 1287b1641c..4cbc7b2bdc 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -38,7 +38,9 @@ from nemo_rl.models.policy import LoRAConfig, PolicyConfig from nemo_rl.models.policy.lm_policy import Policy -model_name = "Qwen/Qwen3-0.6B" +# model_name = "Qwen/Qwen3-0.6B" +# model_name = "Qwen/Qwen2.5-1.5B" +model_name = "unsloth/Llama-3.2-1B-Instruct" # Define basic vLLM test config basic_vllm_test_config: VllmConfig = { "backend": "vllm", @@ -2533,15 +2535,7 @@ def test_vllm_megatron_weight_update_with_packing(cluster, test_input_data): vllm_generation.shutdown() -# ANSI color codes -CYAN = "\033[96m" -GREEN = "\033[92m" -YELLOW = "\033[93m" -BLUE = "\033[94m" -MAGENTA = "\033[95m" -RED = "\033[91m" -BOLD = "\033[1m" -RESET = "\033[0m" +from nemo_rl.utils.logger import BLUE, print_colored def test_vllm_lora_refit_sync_colocated(cluster, tokenizer): @@ -2557,53 +2551,52 @@ def test_vllm_lora_refit_sync_colocated(cluster, tokenizer): dtensor_config["dtensor_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) dtensor_config["dtensor_cfg"]["lora_cfg"]["enabled"] = True - print(f"\n{CYAN}{BOLD}{'=' * 80}\n>>> CREATING DTENSOR POLICY\n{'=' * 80}{RESET}") + print_colored("CREATING DTENSOR POLICY", BLUE) lm_policy = Policy(cluster, dtensor_config, tokenizer) - print(f"\n{CYAN}{BOLD}{'=' * 80}\n>>> CREATING VLLM POLICY\n{'=' * 80}{RESET}") + print_colored("CREATING VLLM POLICY", BLUE) vllm_policy = VllmGeneration(cluster, vllm_config) vllm_policy.finish_generation() - print(f"\n{YELLOW}{BOLD}{'=' * 80}\n>>> PREPARING REFIT INFO\n{'=' * 80}{RESET}") + print_colored("PREPARING REFIT INFO", BLUE) state_dict_info = lm_policy.prepare_refit_info() vllm_policy.prepare_refit_info(state_dict_info) # take it outside statistics to get clean peak memory during refit lm_policy.offload_before_refit() - print( - f"\n{YELLOW}{BOLD}{'=' * 80}\n>>> STARTING VLLM POLICY REFIT BASE MODEL WEIGHTS\n{'=' * 80}{RESET}" - ) + print_colored("STARTING VLLM POLICY REFIT BASE MODEL WEIGHTS", BLUE) refit_policy_generation( lm_policy, vllm_policy, vllm_config["colocated"]["enabled"], - _refit_buffer_size_gb=1.5, + _refit_buffer_size_gb=3, refit_base_model_weights=True, - refit_lora_weights=True, + refit_lora_weights=vllm_config["vllm_cfg"]["lora_cfg"]["enabled"], ) - print(f"\n{YELLOW}{BOLD}{'=' * 80}\n>>> GETTING LORA LAYERS\n{'=' * 80}{RESET}") - lora_layers = vllm_policy.get_lora_layers()[0][0] - for layer in lora_layers: - for a_weight in layer["a_weights"]: - assert torch.all(a_weight == 1) - for b_weight in layer["b_weights"]: - assert torch.all(b_weight == 0) - - -def test_vllm_lora_generation(cluster, tokenizer): - """Test vLLM LoRA refit with sync engine and colocated setup.""" - vllm_config = deepcopy(basic_vllm_test_config) - vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) - vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = True - vllm_config["vllm_cfg"]["async_engine"] = False - vllm_config = configure_generation_config(vllm_config, tokenizer) - - print(f"\n{CYAN}{BOLD}{'=' * 80}\n>>> CREATING VLLM POLICY\n{'=' * 80}{RESET}") - vllm_policy = VllmGeneration(cluster, vllm_config) - vllm_policy.prepare_for_generation() - - print(f"\n{CYAN}{BOLD}{'=' * 80}\n>>> GENERATING TEXT\n{'=' * 80}{RESET}") + # vllm_model_state_dict = vllm_policy.get_model_state_dict()[0][0] + # print_colored(f"VLLM MODEL STATE DICT: 
{vllm_model_state_dict.keys()}", BLUE) + # from transformers import AutoModel + + # model = AutoModel.from_pretrained(model_name) + # model_state_dict = model.state_dict() + + # for name, vllm_tensor in vllm_model_state_dict.items(): + # name = name.replace("model.", "") + # model_tensor = model_state_dict[name] + # print_colored( + # f"NAME: {name}, vllm type : {vllm_tensor.dtype}, model type: {model_tensor.dtype}", + # BLUE, + # ) + # vllm_tensor = vllm_tensor.to("cpu") + # model_tensor = model_tensor.to(vllm_tensor.dtype).to("cpu") + # if not torch.allclose(vllm_tensor, model_tensor): + # print_colored(f"Tensor {name} is not close", RED) + # print_colored(f"MODEL TENSOR: {model_tensor.shape}", RED) + # print_colored(f"VLLM TENSOR: {vllm_tensor.shape}", RED) + # assert False, f"Tensor {name} is not close" + + print_colored("GENERATING TEXT", BLUE) prompts = [ "What is the largest number, all of whose digits are 1 or 4, and whose digits add up to 12?" ] @@ -2622,9 +2615,45 @@ def test_vllm_lora_generation(cluster, tokenizer): "input_lengths": tokenized["attention_mask"].sum(dim=1).to(torch.int32), } ) + vllm_policy.prepare_for_generation() outputs = vllm_policy.generate(test_input_data, greedy=True) output_ids = outputs["output_ids"] generated_texts = test_tokenizer.batch_decode(output_ids, skip_special_tokens=True) - print( - f"\n{CYAN}{BOLD}{'=' * 80}\n>>> GENERATED TEXT:\n{generated_texts}\n{'=' * 80}{RESET}" - ) + print_colored(f"GENERATED TEXT: {generated_texts}") + + +def test_vllm_lora_generation(cluster, tokenizer): + """Test vLLM LoRA refit with sync engine and colocated setup.""" + vllm_config = deepcopy(basic_vllm_test_config) + vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = True + vllm_config["vllm_cfg"]["async_engine"] = False + vllm_config = configure_generation_config(vllm_config, tokenizer) + + print_colored("CREATING VLLM POLICY") + vllm_policy = VllmGeneration(cluster, vllm_config) + vllm_policy.prepare_for_generation() + + # print_colored("GENERATING TEXT") + # prompts = [ + # "What is the largest number, all of whose digits are 1 or 4, and whose digits add up to 12?" 
+ # ] + # test_tokenizer = get_tokenizer({"name": model_name}) + # tokenized = test_tokenizer( + # prompts, + # padding=True, + # truncation=True, + # max_length=256, + # return_tensors="pt", + # padding_side="right", + # ) + # test_input_data = BatchedDataDict( + # { + # "input_ids": tokenized["input_ids"], + # "input_lengths": tokenized["attention_mask"].sum(dim=1).to(torch.int32), + # } + # ) + # outputs = vllm_policy.generate(test_input_data, greedy=True) + # output_ids = outputs["output_ids"] + # generated_texts = test_tokenizer.batch_decode(output_ids, skip_special_tokens=True) + # print_colored(f"GENERATED TEXT: {generated_texts}") From 4354c9c6e2f6e0460874edf1a625093fff652027 Mon Sep 17 00:00:00 2001 From: ruit Date: Sat, 27 Dec 2025 21:21:34 -0800 Subject: [PATCH 3/6] support async and non-colocated Signed-off-by: ruit --- examples/configs/grpo_math_1B.yaml | 2 +- .../grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml | 27 ++ nemo_rl/algorithms/grpo.py | 44 +++- nemo_rl/models/generation/interfaces.py | 4 +- nemo_rl/models/generation/lora.py | 2 +- .../models/generation/vllm/vllm_backend.py | 245 ++++++++++-------- .../models/generation/vllm/vllm_generation.py | 6 +- nemo_rl/models/generation/vllm/vllm_worker.py | 18 +- .../generation/vllm/vllm_worker_async.py | 18 +- nemo_rl/models/policy/interfaces.py | 5 +- nemo_rl/models/policy/lm_policy.py | 9 +- .../policy/workers/dtensor_policy_worker.py | 8 +- .../workers/dtensor_policy_worker_v2.py | 23 +- .../policy/workers/megatron_policy_worker.py | 6 +- nemo_rl/utils/weights.py | 12 + .../llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh | 41 +++ 16 files changed, 309 insertions(+), 161 deletions(-) create mode 100644 examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml create mode 100644 nemo_rl/utils/weights.py create mode 100644 tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 46982537ad..f0388590bc 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -92,7 +92,7 @@ policy: custom_parallel_plan: null # LoRA (Low-Rank Adaptation) Configuration lora_cfg: - enabled: True # Set to True to enable LoRA fine-tuning + enabled: False # Set to True to enable LoRA fine-tuning target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers) exclude_modules: [] # List of module names to exclude from LoRA match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules) diff --git a/examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml b/examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml new file mode 100644 index 0000000000..0d7698471f --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml @@ -0,0 +1,27 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + val_at_start: true +checkpointing: + checkpoint_dir: results/grpo-qwen3-8B-base-1n8g-fsdp2-lora +policy: + model_name: Qwen/Qwen3-8B-Base + max_total_sequence_length: 4096 + dtensor_cfg: + activation_checkpointing: true + lora_cfg: + enabled: True + dim: 256 + alpha: 512 + sequence_packing: + enabled: false +data: + shuffle: false +logger: + log_dir: logs/grpo-qwen3-8B-base-1n8g-fsdp2-lora + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-qwen3-8B-base-1n8g-fsdp2-lora +cluster: + gpus_per_node: 8 diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 
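For reference, the recipe above sets dim=256 with alpha=512. Assuming the conventional LoRA parameterization, where the low-rank update is scaled by alpha/r, this gives an effective multiplier of 2.0 on the adapter output. The sketch below uses illustrative shapes and is not the repo's implementation.

# Sketch of the conventional LoRA update for one linear layer, using the
# rank/alpha values from the recipe above (dim=256, alpha=512).
import torch

d_in, d_out = 4096, 4096      # illustrative hidden sizes
r, alpha = 256, 512           # lora_cfg.dim and lora_cfg.alpha from the recipe
scaling = alpha / r           # 2.0: multiplier applied to the low-rank update

W = torch.randn(d_out, d_in)          # frozen base weight
A = torch.randn(r, d_in) * 0.01       # LoRA A (small random init; the config uses "xavier")
B = torch.zeros(d_out, r)             # LoRA B starts at zero, so the initial update is zero

x = torch.randn(2, d_in)
y = x @ (W + scaling * (B @ A)).T     # merged view of base + LoRA, for illustration
assert y.shape == (2, d_out)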
de244256c8..298fcf52c3 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -974,8 +974,15 @@ def _perform_refit_weights( update_success = all(result for result in results if result is not None) else: # update weights through nccl - futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales) - futures_inference = policy_generation.update_weights_from_collective() + futures_train = policy.broadcast_weights_for_collective( + kv_scales=kv_scales, + refit_base_model_weights=refit_base_model_weights, + refit_lora_weights=refit_lora_weights, + ) + futures_inference = policy_generation.update_weights_from_collective( + refit_base_model_weights=refit_base_model_weights, + refit_lora_weights=refit_lora_weights, + ) # wait for all futures to complete ray.get(futures_train) results = ray.get(futures_inference) @@ -1525,14 +1532,8 @@ def grpo_train( loss_multiplier[truncated] = 0 repeated_batch["loss_multiplier"] = loss_multiplier # Add loss mask and advantages to each message in LLMMessageLogType - # print_colored(f"length of message_log: {len(repeated_batch['message_log'])}", BLUE) - # print_colored(f"repeated_batch['message_log']: {repeated_batch['message_log']}", BLUE) for i, message_log in enumerate(repeated_batch["message_log"]): - # print_colored(f"length of message_log {i}: {len(message_log)}", BLUE) - # print_colored(f"message_log {i}: {message_log}", BLUE) for j, message in enumerate(message_log): - # print_colored(f"length of message {i}, {j}: {len(message)}", BLUE) - # print_colored(f"message {i}, {j}: {message}", BLUE) if message["role"] == "assistant": message["token_loss_mask"] = torch.ones_like( message["token_ids"] @@ -1542,7 +1543,6 @@ def grpo_train( message["token_ids"] ) if "generation_logprobs" not in message: - # print_colored(f"Set Generation logprobs to zeros for message {i}, {j}") message["generation_logprobs"] = torch.zeros_like( message["token_ids"], dtype=torch.float32 ) @@ -2137,6 +2137,8 @@ def async_grpo_train( policy_generation = policy NEED_REFIT = False POLICY_GENERATION_STALE = True + REFIT_BASE_MODEL_WEIGHTS = True + REFIT_LORA_WEIGHTS = policy.lora_enabled assert policy_generation is not None # Training state @@ -2248,9 +2250,16 @@ def async_grpo_train( if NEED_REFIT and POLICY_GENERATION_STALE: print("šŸ”„ Refitting policy generation with actual model weights...") try: - refit_policy_generation(policy, policy_generation, colocated_inference) + refit_policy_generation( + policy, + policy_generation, + colocated_inference, + refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, + refit_lora_weights=REFIT_LORA_WEIGHTS, + ) print("āœ… Policy generation refit completed successfully") POLICY_GENERATION_STALE = False + REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True except Exception as e: print(f"āŒ Policy generation refit failed: {e}") import traceback @@ -2568,10 +2577,14 @@ def async_grpo_train( print("šŸ”„ Performing policy generation refit...") with timer.time("weight_sync"): refit_policy_generation( - policy, policy_generation, colocated_inference + policy, + policy_generation, + colocated_inference, + refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, + refit_lora_weights=REFIT_LORA_WEIGHTS, ) POLICY_GENERATION_STALE = False - + REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True # Update weight version before resuming trajectory collection so that all trajectories are updated with the new correct weight version weight_version += 1 trajectory_collector.set_weight_version.remote(weight_version) @@ 
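The REFIT_* flags introduced here encode a simple schedule: with LoRA enabled, the full base model is synced to the generation backend once, and every subsequent weight sync transfers only the adapter tensors. A condensed sketch of that schedule, with the actual refit call stubbed out:

# Condensed sketch of the refit schedule implied by the flags above.
def run_refits(num_steps: int, lora_enabled: bool):
    refit_base = True
    refit_lora = lora_enabled
    for step in range(num_steps):
        # stands in for refit_policy_generation(policy, policy_generation, ...)
        print(f"step {step}: base={refit_base}, lora={refit_lora}")
        # after a successful refit, base weights stay frozen in the LoRA case
        refit_base = False if refit_lora else True

run_refits(3, lora_enabled=True)   # base weights pushed only on step 0, LoRA every step
run_refits(3, lora_enabled=False)  # full base refit on every step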
-2593,9 +2606,14 @@ def async_grpo_train( if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( - policy, policy_generation, colocated_inference + policy, + policy_generation, + colocated_inference, + refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, + refit_lora_weights=REFIT_LORA_WEIGHTS, ) POLICY_GENERATION_STALE = False + REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True else: policy_generation.prepare_for_generation() val_metrics, validation_timings = validate( diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 31f9536b02..5eafba0540 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -251,7 +251,9 @@ def update_weights_via_ipc_zmq( """Update the model weights from the given IPC handles.""" raise NotImplementedError - def update_weights_from_collective(self) -> list[ray.ObjectRef]: + def update_weights_from_collective( + self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False + ) -> list[ray.ObjectRef]: """Update the model weights from collective communication.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/lora.py b/nemo_rl/models/generation/lora.py index 0f6b1f655c..b4b187fe39 100644 --- a/nemo_rl/models/generation/lora.py +++ b/nemo_rl/models/generation/lora.py @@ -39,8 +39,8 @@ def patched_load_adapter(self, lora_request: LoRARequestWithCfgAndWeights): expected_lora_lst.append(module) expected_lora_modules = set(expected_lora_lst) lora_weights = None - if isinstance(lora_request, LoRARequestWithCfgAndWeights): + # if hasattr(lora_request, "lora_weights") and getattr(lora_request, "lora_weights", None) is not None: lora_cfg = lora_request.lora_cfg lora_weights = lora_request.lora_weights peft_helper = PEFTHelper.from_dict(lora_cfg) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 7cefa3e512..efa3443595 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -23,9 +23,12 @@ calculate_aligned_size, rebuild_cuda_tensor_from_ipc, ) -from nemo_rl.utils.logger import RED, print_colored from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_consumer +from nemo_rl.utils.weights import ( + is_base_model_weight_name, + is_lora_weight_name, +) try: import vllm # noqa: F401 @@ -127,44 +130,122 @@ def _maybe_process_fp8_kv_cache(self) -> None: target_device, ) - def map_param_name(self, param_name: str) -> str: - lora_mgr = self.model_runner.model.lora_manager - supported_modules = lora_mgr.supported_lora_modules - packed_modules_mapping = lora_mgr.packed_modules_mapping - - parts = param_name.split(".") - if len(parts) < 2: - return param_name - - base_name = ".".join(parts[:-2]) # prefix - module_name = parts[-2] # e.g. q_proj/k_proj/v_proj/gate_proj/up_proj/... 
- field_name = parts[-1] # weight/bias - - resolved_module_name = module_name - for packed_name, member_names in packed_modules_mapping.items(): - if module_name in member_names: - resolved_module_name = packed_name - break + def apply_lora_patches(self) -> None: + """Apply LoRA patches inside the vLLM worker process.""" + try: + from nemo_rl.models.generation.lora import apply_lora_patches - # use resolved_module_name for checking, but return the original module_name - if resolved_module_name in supported_modules: - if base_name != "": - return f"{base_name}.{module_name}.base_layer.{field_name}" - else: - return f"{module_name}.base_layer.{field_name}" + apply_lora_patches() + except Exception as e: + print(f"Failed to apply LoRA patches in worker extension: {e}") + import traceback as _tb - return param_name + print(_tb.format_exc()) + raise e def _apply_weight_name_mapping( self, weights: list[tuple[str, torch.Tensor]] ) -> list[tuple[str, torch.Tensor]]: """Apply weight name mapping if LoRA is enabled.""" + + def map_param_name(param_name: str) -> str: + lora_mgr = self.model_runner.model.lora_manager + supported_modules = lora_mgr.supported_lora_modules + packed_modules_mapping = lora_mgr.packed_modules_mapping + + parts = param_name.split(".") + if len(parts) < 2: + return param_name + + base_name = ".".join(parts[:-2]) # prefix + module_name = parts[-2] # e.g. q_proj/k_proj/v_proj/gate_proj/up_proj/... + field_name = parts[-1] # weight/bias + + resolved_module_name = module_name + for packed_name, member_names in packed_modules_mapping.items(): + if module_name in member_names: + resolved_module_name = packed_name + break + + # use resolved_module_name for checking, but return the original module_name + if resolved_module_name in supported_modules: + if base_name != "": + return f"{base_name}.{module_name}.base_layer.{field_name}" + else: + return f"{module_name}.base_layer.{field_name}" + return param_name + new_weights = [] for name, w in weights: - new_name = self.map_param_name(name) + new_name = map_param_name(name) new_weights.append((new_name, w)) return new_weights + def _apply_loaded_weights( + self, + weights: list[tuple[str, torch.Tensor]], + lora_config: dict[str, Any], + refit_base_model_weights: bool, + refit_lora_weights: bool, + ) -> None: + """Apply loaded weights to model or LoRA based on flags. + + This unifies the duplicate logic used by both IPC and collective paths. 
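The nested map_param_name above renames base-model weights so they land on the .base_layer attribute that vLLM's LoRA wrappers use for the frozen linear weight. A standalone sketch of the renaming rule follows; the supported/packed module tables are hard-coded placeholders here, whereas the worker reads them from vLLM's LoRA manager.

# Sketch of the rename applied when pushing base weights into a LoRA-enabled vLLM engine.
SUPPORTED = {"qkv_proj", "o_proj", "gate_up_proj", "down_proj"}            # placeholder
PACKED = {"qkv_proj": ["q_proj", "k_proj", "v_proj"],
          "gate_up_proj": ["gate_proj", "up_proj"]}                        # placeholder

def map_param_name(param_name: str) -> str:
    parts = param_name.split(".")
    if len(parts) < 2:
        return param_name
    prefix, module, field = ".".join(parts[:-2]), parts[-2], parts[-1]
    resolved = next((p for p, members in PACKED.items() if module in members), module)
    if resolved in SUPPORTED:
        return f"{prefix}.{module}.base_layer.{field}" if prefix else f"{module}.base_layer.{field}"
    return param_name

assert map_param_name("model.layers.0.self_attn.q_proj.weight") == \
    "model.layers.0.self_attn.q_proj.base_layer.weight"
assert map_param_name("model.embed_tokens.weight") == "model.embed_tokens.weight"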
+ """ + from nemo_rl.models.generation import fp8 + + runner = self.model_runner + + if fp8.is_fp8_model(runner.vllm_config): + # the fp8 load_weights additionally casts bf16 weights into fp8 + fp8.load_weights(weights, runner) + return + + if refit_base_model_weights: + if lora_config and "enabled" in lora_config and lora_config["enabled"]: + weights = self._apply_weight_name_mapping(weights) + runner.model.load_weights(weights=weights) + return + + if refit_lora_weights: + assert lora_config, ( + "lora_config is not provided, can not refit lora weights" + ) + from nemo_rl.models.generation.lora import ( + LoRARequestWithCfgAndWeights, + get_vllm_lora_metadata, + ) + + lora_cfg_dict = dict( + { + "r": lora_config["dim"], + "lora_alpha": lora_config["alpha"], + "target_modules": lora_config["target_modules"], + } + ) + lora_metadata = get_vllm_lora_metadata() + # Note: We don't need to remove the lora if it is already set max_loras = 1 + self.remove_lora(lora_id=lora_metadata["lora_int_id"]) + lora_request = LoRARequestWithCfgAndWeights( + **lora_metadata, + lora_cfg=lora_cfg_dict, + lora_weights=dict({name: tensor for name, tensor in weights}), + ) + try: + self.add_lora(lora_request=lora_request) + except Exception as e: + print( + f"Error in VllmInternalWorkerExtension._apply_loaded_weights: {e}" + ) + print(traceback.format_exc()) + raise e + # self.add_lora(lora_request=lora_request) + return + + raise ValueError( + "refit_base_model_weights and refit_lora_weights cannot be both False" + ) + @wrap_with_nvtx_name("vllm_internal_worker_extension/update_weights_via_ipc_zmq") def update_weights_via_ipc_zmq( self, @@ -180,10 +261,6 @@ def update_weights_via_ipc_zmq( buffer = None weights = None - print_colored( - f"lora_config in update_weights_via_ipc_zmq: {self.model_runner.vllm_config.lora_config}", - RED, - ) try: self.maybe_init_zmq() while True: @@ -225,61 +302,13 @@ def update_weights_via_ipc_zmq( assert offset == used_bytes, ( "Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info" ) - # Load weights into the model - from nemo_rl.models.generation import fp8 - - if fp8.is_fp8_model(self.model_runner.vllm_config): - # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, self.model_runner) - else: - if refit_base_model_weights: - # Apply weight name mapping if LoRA is enabled - if ( - lora_config - and "enabled" in lora_config - and lora_config["enabled"] - ): - weights = self._apply_weight_name_mapping(weights) - self.model_runner.model.load_weights(weights=weights) - print_colored("updated base model weights", RED) - elif refit_lora_weights: - assert lora_config, ( - "lora_config is not provided, can not refit lora weights" - ) - from nemo_rl.models.generation.lora import ( - LoRARequestWithCfgAndWeights, - get_vllm_lora_metadata, - ) - - # Convert vLLM LoRAConfig object to dict for PEFTHelper - # LoRAConfig(max_lora_rank=8, max_loras=1, fully_sharded_loras=False, max_cpu_loras=1, lora_dtype=torch.bfloat16, lora_extra_vocab_size=256, default_mm_loras=None, bias_enabled=False) - - lora_cfg_dict = dict( - { - "r": lora_config["dim"], - "lora_alpha": lora_config["alpha"], - "target_modules": lora_config["target_modules"], - } - ) - lora_metadata = get_vllm_lora_metadata() - # Note: We don't need to remove the lora if it is already set max_loras = 1 - self.remove_lora(lora_id=lora_metadata["lora_int_id"]) - lora_request = LoRARequestWithCfgAndWeights( - **lora_metadata, - 
lora_cfg=lora_cfg_dict, - lora_weights=dict( - { - name_weight[0]: name_weight[1] - for name_weight in weights - } - ), - ) - self.add_lora(lora_request=lora_request) - print_colored("updated lora weights", RED) - else: - raise ValueError( - "refit_base_model_weights and refit_lora_weights cannot be both False" - ) + # Load weights into the model or LoRA + self._apply_loaded_weights( + weights=weights, + lora_config=lora_config, + refit_base_model_weights=refit_base_model_weights, + refit_lora_weights=refit_lora_weights, + ) torch.cuda.current_stream().synchronize() @@ -309,36 +338,38 @@ def update_weights_via_ipc_zmq( @wrap_with_nvtx_name( "vllm_internal_worker_extension/update_weights_from_collective" ) - def update_weights_from_collective(self) -> bool: + def update_weights_from_collective( + self, + lora_config: dict[str, Any] = {}, + refit_base_model_weights: bool = True, + refit_lora_weights: bool = False, + ) -> bool: """Update the model weights from collective communication.""" assert self.state_dict_info is not None, ( "state_dict_info is not prepared. " "Please call prepare_refit_info when initializing the worker." ) - def _load_model_weights(weights, model_runner): - """Load model weights. - - Args: - weights: List[(name, tensor)] - model_runner: vLLM ModelRunner - - Returns: - None - """ - from nemo_rl.models.generation import fp8 - - if fp8.is_fp8_model(model_runner.vllm_config): - # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, model_runner) - else: - model_runner.model.load_weights(weights=weights) - - load_model_weight_func = lambda x: _load_model_weights(x, self.model_runner) + def _filtered_state_dict_iterator(): + """Iterator that yields only base model weights when skip_base_model_weights is True.""" + for name, tensor_tuple in self.state_dict_info.items(): + # Skip base model weights if skip_base_model_weights is True + if is_base_model_weight_name(name) and not refit_base_model_weights: + continue + if is_lora_weight_name(name) and not refit_lora_weights: + continue + yield name, tensor_tuple + + load_model_weight_func = lambda weights: self._apply_loaded_weights( + weights=weights, + lora_config=lora_config, + refit_base_model_weights=refit_base_model_weights, + refit_lora_weights=refit_lora_weights, + ) try: packed_broadcast_consumer( - iterator=iter(self.state_dict_info.items()), + iterator=_filtered_state_dict_iterator(), group=self.model_update_group, src=0, post_unpack_func=load_model_weight_func, diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 751c02db4b..b8b07d4de9 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -792,7 +792,9 @@ def update_weights_via_ipc_zmq( # this function should co-work with lm_policy, so we should wait for all futures to complete outside return futures - def update_weights_from_collective(self) -> list[ray.ObjectRef]: + def update_weights_from_collective( + self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False + ) -> list[ray.ObjectRef]: """Update weights of the policy using collective communication.""" if not self.worker_group or not self.worker_group.workers: raise RuntimeError("Worker group is not initialized") @@ -808,6 +810,8 @@ def update_weights_from_collective(self) -> list[ray.ObjectRef]: futures = self.worker_group.run_all_workers_single_data( method_name, run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], 
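The LoRA branch of _apply_loaded_weights drops the previously registered adapter id and re-adds a request that carries both the PEFT-style config and the fresh tensors. A self-contained sketch of that swap order is below; EngineStub stands in for the vLLM worker (which exposes add_lora/remove_lora), and LoRARequestSketch only mirrors the project's LoRARequestWithCfgAndWeights rather than reproducing it.

# Sketch of the adapter swap performed during a LoRA-only refit.
from dataclasses import dataclass, field
import torch

@dataclass
class LoRARequestSketch:
    lora_name: str
    lora_int_id: int
    lora_cfg: dict                      # PEFT-style: r, lora_alpha, target_modules
    lora_weights: dict = field(default_factory=dict)

class EngineStub:
    def __init__(self):
        self.adapters = {}
    def remove_lora(self, lora_id: int):
        self.adapters.pop(lora_id, None)   # no-op if nothing was loaded yet
    def add_lora(self, request: LoRARequestSketch):
        self.adapters[request.lora_int_id] = request

lora_cfg = {"r": 8, "lora_alpha": 32, "target_modules": []}
new_weights = {"model.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(8, 16)}

engine = EngineStub()
engine.remove_lora(1)   # max_loras=1, so the single adapter slot is always reused
engine.add_lora(LoRARequestSketch("policy_adapter", 1, lora_cfg, new_weights))
assert 1 in engine.adapters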
+ refit_base_model_weights=refit_base_model_weights, + refit_lora_weights=refit_lora_weights, ) # this function should co-work with lm_policy, so we should wait for all futures to complete outside diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 7ac3071b48..be4d25c977 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -34,7 +34,6 @@ from nemo_rl.models.generation.vllm.utils import format_prompt_for_vllm_generation from nemo_rl.models.huggingface.common import ModelFlag from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled -from nemo_rl.utils.logger import print_colored from nemo_rl.utils.nsys import wrap_with_nvtx_name @@ -393,9 +392,12 @@ def _patch_vllm_vit_flash_attn_backend(): # Lora is enabled, add it to the vllm kwargs self.lora_enabled = False if self.lora_cfg is not None and self.lora_cfg["enabled"]: - from nemo_rl.models.generation.lora import apply_lora_patches + try: + from nemo_rl.models.generation.lora import apply_lora_patches - apply_lora_patches() + apply_lora_patches() + except Exception as e: + print(f"[WARNING] Failed to apply lora patches (sync worker): {e}") self.lora_enabled = True vllm_kwargs["enable_lora"] = True vllm_kwargs["max_loras"] = 1 # only support one lora adapter @@ -575,7 +577,6 @@ def generate( lora_req = None if self.lora_enabled: - # print_colored(f"list_lora in generate: {self.llm.llm_engine.list_loras()}") from vllm.lora.request import LoRARequest from nemo_rl.models.generation.lora import get_vllm_lora_metadata @@ -584,7 +585,6 @@ def generate( lora_req = LoRARequest( **lora_metadata, ) - print_colored(f"lora_req in generate: {lora_req}") outputs = self.llm.generate(prompts, sampling_params, lora_request=lora_req) # Process the outputs - but preserve the original input padding structure @@ -778,7 +778,9 @@ def update_weights_via_ipc_zmq( return False @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_collective") - def update_weights_from_collective(self) -> bool: + def update_weights_from_collective( + self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False + ) -> bool: """Update the model weights from collective communication.""" try: assert self.llm is not None, ( @@ -791,7 +793,8 @@ def update_weights_from_collective(self) -> bool: ) result_or_coro = self.llm.collective_rpc( - "update_weights_from_collective", args=tuple() + "update_weights_from_collective", + args=(self.lora_cfg, refit_base_model_weights, refit_lora_weights), ) worker_result = result_or_coro[0] @@ -820,7 +823,6 @@ def _get_lora_layers(self): details = [] for name, module in model.named_modules(): - print_colored(f"name: {name}") if isinstance(module, BaseLinearLayerWithLoRA): a_shapes = [tuple(t.shape) for t in module.lora_a_stacked] b_shapes = [tuple(t.shape) for t in module.lora_b_stacked] diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 0e4ea5cdeb..c687098f4f 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -276,6 +276,17 @@ def clear_vllm_logger_metrics(self) -> None: async def post_init_async(self): self.vllm_device_ids = await self.report_device_id_async() + # Ensure LoRA patches are applied inside engine worker processes (async path) + if getattr(self, "lora_enabled", False) and self.llm is not None: + try: + await 
self.llm.collective_rpc("apply_lora_patches", args=tuple()) + print( + "Successfully applied lora patches in engine workers (async worker)" + ) + except Exception as e: + print( + f"[WARNING] Failed to apply lora patches in engine workers (async worker): {e}" + ) async def report_dp_openai_server_base_url(self) -> Optional[str]: return self.base_url @@ -1025,7 +1036,9 @@ async def update_weights_via_ipc_zmq_async( traceback.print_exc() return False - async def update_weights_from_collective_async(self) -> bool: + async def update_weights_from_collective_async( + self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False + ) -> bool: """Async version of update_weights_from_collective.""" try: assert self.llm is not None, ( @@ -1038,7 +1051,8 @@ async def update_weights_from_collective_async(self) -> bool: ) result_or_coro = await self.llm.collective_rpc( - "update_weights_from_collective", args=tuple() + "update_weights_from_collective", + args=(self.lora_cfg, refit_base_model_weights, refit_lora_weights), ) if asyncio.iscoroutine(result_or_coro): diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 144b0c517d..7d5f6167ac 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -184,7 +184,10 @@ def stream_weights_via_ipc_zmq( @abstractmethod def broadcast_weights_for_collective( - self, kv_scales: Optional[dict[str, float]] = None + self, + kv_scales: Optional[dict[str, float]] = None, + refit_base_model_weights: bool = True, + refit_lora_weights: bool = False, ) -> list[ray.ObjectRef]: pass diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 521b9945a9..948bad5192 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -114,9 +114,7 @@ def __init__( use_v2 = config.get("dtensor_cfg", {}).get("_v2", False) lora_cfg = config.get("dtensor_cfg", {}).get("lora_cfg", {}) self.lora_enabled = lora_cfg.get("enabled", False) - from nemo_rl.utils.logger import print_colored - print_colored(f"LORA ENABLED in lm_policy: {self.lora_enabled}") if use_v2: worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" @@ -782,12 +780,17 @@ def stream_weights_via_ipc_zmq( return futures def broadcast_weights_for_collective( - self, kv_scales: Optional[dict[str, float]] = None + self, + kv_scales: Optional[dict[str, float]] = None, + refit_base_model_weights: bool = True, + refit_lora_weights: bool = False, ) -> list[ray.ObjectRef]: """Broadcast the weights for collective communication.""" futures = self.worker_group.run_all_workers_single_data( "broadcast_weights_for_collective", kv_scales=kv_scales, + refit_base_model_weights=refit_base_model_weights, + refit_lora_weights=refit_lora_weights, ) # this function should co-work with vllm, so we should wait for all futures to complete outside return futures diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index ccd8d6053e..7be97e21d1 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -1725,9 +1725,15 @@ def dtensor_params_generator(): @torch.no_grad() def broadcast_weights_for_collective( - self, kv_scales: Optional[dict[str, float]] = None + self, + kv_scales: Optional[dict[str, float]] = None, + refit_base_model_weights: bool = True, + refit_lora_weights: bool = False, ) -> None: """Broadcast the 
weights for collective communication.""" + assert refit_base_model_weights and refit_lora_weights == False, ( + "dtensor v1 not support lora. refit_lora_weights must be False" + ) if kv_scales is not None: raise NotImplementedError( "FP8 kvcache is not currently supported for DTensor path, we will support it in the future." diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 02b3d45801..6a6c3f5b71 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -91,6 +91,7 @@ from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer +from nemo_rl.utils.weights import is_base_model_weight_name, is_lora_weight_name STRING_TO_DTYPE = { "float32": torch.float32, @@ -292,7 +293,7 @@ def __init__( ] else: sdpa_method = None - + print(f"[Rank {self.rank}] sdpa_method: {sdpa_method}") self.model = model_class.from_pretrained( model_name, attn_implementation=attn_impl, @@ -1689,18 +1690,6 @@ def return_model_config(self) -> dict[str, Any]: """ return self.model.config - def _is_lora_weight(self, name: str) -> bool: - """Check if the weight is a lora weight.""" - return ( - name.endswith(".lora_A.weight") - or name.endswith(".lora_B.weight") - or name.endswith(".lora_scaling.weight") - ) - - def _is_base_model_weight(self, name: str) -> bool: - """Check if the weight is a base model weight.""" - return not self._is_lora_weight(name) - @torch.no_grad() def prepare_refit_info(self) -> Optional[dict[str, Any]]: """Prepare state dict metadata for weight refitting and IPC streaming.""" @@ -1754,10 +1743,10 @@ def dtensor_params_generator(): """ for name, tensor in self.model.state_dict().items(): # Skip base model weights if skip_base_model_weights is True - if self._is_base_model_weight(name) and not refit_base_model_weights: + if is_base_model_weight_name(name) and not refit_base_model_weights: continue - if self._is_lora_weight(name) and not refit_lora_weights: + if is_lora_weight_name(name) and not refit_lora_weights: continue if isinstance(tensor, DTensor): @@ -1816,9 +1805,9 @@ def _filtered_state_dict_iterator(): """Iterator that yields only base model weights when skip_base_model_weights is True.""" for name, tensor in self.model.state_dict().items(): # Skip base model weights if skip_base_model_weights is True - if self._is_base_model_weight(name) and not refit_base_model_weights: + if is_base_model_weight_name(name) and not refit_base_model_weights: continue - if self._is_lora_weight(name) and not refit_lora_weights: + if is_lora_weight_name(name) and not refit_lora_weights: continue yield (name, tensor) diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index caece09081..66767822e6 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -2118,11 +2118,7 @@ def _iter_params_with_optional_kv_scales( @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") def stream_weights_via_ipc_zmq( - self, - buffer_size_bytes: int = 0, - kv_scales: Optional[dict[str, float]] = None, - refit_base_model_weights: bool = True, - refit_lora_weights: bool = False, + self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None ) -> None: 
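The DTensor worker decides which tensors to stream purely from parameter names, using the suffix test that moved into nemo_rl/utils/weights.py. A toy sketch of that split follows; the tensors are dummies, and the DTensor gather performed by the real generator is only noted in a comment.

# Sketch of the name-based split used when filtering a state dict for refit.
import torch

LORA_SUFFIXES = (".lora_A.weight", ".lora_B.weight", ".lora_scaling.weight")

def is_lora_weight_name(name: str) -> bool:
    return name.endswith(LORA_SUFFIXES)

def filtered_params(state_dict, refit_base: bool, refit_lora: bool):
    for name, tensor in state_dict.items():
        is_lora = is_lora_weight_name(name)
        if (is_lora and not refit_lora) or (not is_lora and not refit_base):
            continue
        # In the worker, DTensor shards are additionally gathered to full tensors here.
        yield name, tensor

toy = {
    "model.layers.0.mlp.down_proj.weight": torch.zeros(4, 4),
    "model.layers.0.mlp.down_proj.lora_A.weight": torch.zeros(2, 4),
}
assert [n for n, _ in filtered_params(toy, refit_base=False, refit_lora=True)] == \
    ["model.layers.0.mlp.down_proj.lora_A.weight"]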
"""Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() diff --git a/nemo_rl/utils/weights.py b/nemo_rl/utils/weights.py new file mode 100644 index 0000000000..db2a5a4c28 --- /dev/null +++ b/nemo_rl/utils/weights.py @@ -0,0 +1,12 @@ +def is_lora_weight_name(name: str) -> bool: + """Return True if a parameter name corresponds to a LoRA weight.""" + return ( + name.endswith(".lora_A.weight") + or name.endswith(".lora_B.weight") + or name.endswith(".lora_scaling.weight") + ) + + +def is_base_model_weight_name(name: str) -> bool: + """Return True if a parameter name corresponds to a base (non-LoRA) weight.""" + return not is_lora_weight_name(name) diff --git a/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh b/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh new file mode 100644 index 0000000000..bd17b54635 --- /dev/null +++ b/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh @@ -0,0 +1,41 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=20 +MAX_STEPS=20 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=30 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/gen_kl_error"]) < 0.001' \ + 'data["train/gen_kl_error"]["20"] < 0.001' \ + 'mean(data["train/reward"]) > 0.56' \ + 'mean(data["timing/train/total_step_time"], 2) < 50' +fi From 7ffe5e1885e4dadb8f19f5bb835dcc59ed645fa3 Mon Sep 17 00:00:00 2001 From: ruit Date: Sun, 28 Dec 2025 19:35:20 -0800 Subject: [PATCH 4/6] add funtional test Signed-off-by: ruit --- .../workers/dtensor_policy_worker_v2.py | 3 +- tests/functional/test_automodel_lora_grpo.sh | 46 +++++++++++++++++++ .../llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh | 11 ++--- .../models/generation/test_vllm_generation.py | 16 +++---- 4 files changed, 61 insertions(+), 15 deletions(-) create mode 100644 tests/functional/test_automodel_lora_grpo.sh diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 6a6c3f5b71..e910984ab7 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -277,7 +277,8 @@ def __init__( # NeMoAutoModelForCausalLM uses flash_attention_2 by default # so we need to set it to None if sequence packing is disabled # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180 - if cp_size > 1 or self.cfg["dtensor_cfg"]["activation_checkpointing"]: + # if cp_size > 1 or 
self.cfg["dtensor_cfg"]["activation_checkpointing"]: + if cp_size > 1: # For cp, match Automodel's `get_train_context` in `cp_utils.py` where only # flash and efficient backends are supported # Ref: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/distributed/cp_utils.py#L57 diff --git a/tests/functional/test_automodel_lora_grpo.sh b/tests/functional/test_automodel_lora_grpo.sh new file mode 100644 index 0000000000..1395250b5d --- /dev/null +++ b/tests/functional/test_automodel_lora_grpo.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# clean up checkpoint directory on exit +trap "rm -rf /tmp/lora_sft_checkpoints" EXIT + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py\ + grpo.max_num_steps=3 \ + grpo.num_prompts_per_step=32 \ + grpo.num_generations_per_prompt=16 \ + policy.dtensor_cfg.lora_cfg.enabled=True \ + policy.dtensor_cfg.lora_cfg.dim=32 \ + +policy.generation.vllm_cfg.skip_tokenizer_init=false \ + policy.dtensor_cfg.tensor_parallel_size=1 \ + policy.train_global_batch_size=512 \ + policy.train_micro_batch_size=4 \ + logger.wandb_enabled=False \ + checkpointing.enabled=false \ + cluster.gpus_per_node=8 \ + checkpointing.checkpoint_dir=/tmp/lora_grpo_checkpoints \ + "$@" \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/reward"]["3"] > 0.07' + diff --git a/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh b/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh index bd17b54635..d8897855a0 100644 --- a/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh +++ b/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh @@ -4,8 +4,8 @@ source $SCRIPT_DIR/common.env # ===== BEGIN CONFIG ===== NUM_NODES=1 -STEPS_PER_RUN=20 -MAX_STEPS=20 +STEPS_PER_RUN=10 +MAX_STEPS=10 NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up NUM_MINUTES=30 # ===== END CONFIG ===== @@ -34,8 +34,7 @@ uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS # Only run metrics if the target step is reached if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then uv run tests/check_metrics.py $JSON_METRICS \ - 'mean(data["train/gen_kl_error"]) < 0.001' \ - 'data["train/gen_kl_error"]["20"] < 0.001' \ - 'mean(data["train/reward"]) > 0.56' \ - 'mean(data["timing/train/total_step_time"], 2) < 50' + 'mean(data["train/gen_kl_error"]) < 0.002' \ + 'mean(data["train/reward"]) > 0.30' \ + 'mean(data["timing/train/total_step_time"], 2) < 150' fi diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 4cbc7b2bdc..0546ffa45e 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ 
b/tests/unit/models/generation/test_vllm_generation.py @@ -38,9 +38,7 @@ from nemo_rl.models.policy import LoRAConfig, PolicyConfig from nemo_rl.models.policy.lm_policy import Policy -# model_name = "Qwen/Qwen3-0.6B" -# model_name = "Qwen/Qwen2.5-1.5B" -model_name = "unsloth/Llama-3.2-1B-Instruct" +model_name = "Qwen/Qwen3-0.6B" # Define basic vLLM test config basic_vllm_test_config: VllmConfig = { "backend": "vllm", @@ -935,12 +933,14 @@ async def test_vllm_generation_with_hf_training_colocated( @pytest.mark.timeout(300) @pytest.mark.asyncio @pytest.mark.parametrize( - ("async_engine", "cpu_offload", "vllm_precision"), + ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"), [ - (True, False, "bfloat16"), - (False, True, "bfloat16"), - (True, False, "fp8"), - (False, True, "fp8"), + (True, False, "bfloat16", True), + (False, True, "bfloat16", True), + (True, False, "bfloat16", False), + (False, True, "bfloat16", False), + (True, False, "fp8", False), + (False, True, "fp8", False), ], ) async def test_vllm_generation_with_hf_training_non_colocated( From 4443948f5004288c188bff246a7db09a935b0270 Mon Sep 17 00:00:00 2001 From: ruit Date: Fri, 2 Jan 2026 03:19:12 -0800 Subject: [PATCH 5/6] add unit test and functional test Signed-off-by: ruit --- nemo_rl/algorithms/grpo.py | 3 + .../generation/vllm/vllm_worker_async.py | 5 +- .../policy/workers/megatron_policy_worker.py | 6 +- .../llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh | 0 tests/test_suites/nightly.txt | 3 + .../models/generation/test_vllm_generation.py | 201 ++++++------------ 6 files changed, 76 insertions(+), 142 deletions(-) mode change 100644 => 100755 tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 298fcf52c3..45dc43b0a5 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -517,6 +517,9 @@ def init_vllm(): assert loss_config["use_importance_sampling_correction"] is True, ( "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" ) + assert not policy_config["dtensor_cfg"]["lora_cfg"]["enabled"], ( + "LoRA is not supported with vLLM FP8 generation." 
+ ) if generation_config["vllm_cfg"]["kv_cache_dtype"].startswith("fp8"): # FP8 KV cache requires FP8 model precision assert generation_config["vllm_cfg"]["precision"] == "fp8", ( diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index c687098f4f..0f7afe0ed3 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -998,7 +998,7 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> Non await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) async def update_weights_via_ipc_zmq_async( - self, + self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False ) -> bool: """Async version of update_weights_via_ipc_zmq.""" try: @@ -1013,7 +1013,8 @@ async def update_weights_via_ipc_zmq_async( # TODO: switch to update_weights_from_local_ipc_handles for better performance once collectively report_device_id is supported in asyncLLM initialization result_or_coro = await self.llm.collective_rpc( - "update_weights_via_ipc_zmq", args=tuple() + "update_weights_via_ipc_zmq", + args=(self.lora_cfg, refit_base_model_weights, refit_lora_weights), ) if asyncio.iscoroutine(result_or_coro): diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 66767822e6..caece09081 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -2118,7 +2118,11 @@ def _iter_params_with_optional_kv_scales( @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") def stream_weights_via_ipc_zmq( - self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None + self, + buffer_size_bytes: int = 0, + kv_scales: Optional[dict[str, float]] = None, + refit_base_model_weights: bool = True, + refit_lora_weights: bool = False, ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() diff --git a/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh b/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh old mode 100644 new mode 100755 diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index e95507105a..df1de61b01 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -59,6 +59,9 @@ tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh tests/test_suites/llm/grpo-nano-v2-12b-1n8g-megatron.sh tests/test_suites/llm/grpo-nano-v2-12b-2n8g-fsdp2tp1.sh +# lora +tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh + ####### # SFT # ####### diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 0546ffa45e..9bd5e5b2ff 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -70,6 +70,7 @@ "skip_tokenizer_init": False, "load_format": "auto", "enforce_eager": "False", + "kv_cache_dtype": "auto", }, "colocated": { "enabled": True, @@ -105,6 +106,7 @@ }, }, "dtensor_cfg": { + "_v2": False, "enabled": True, "cpu_offload": False, "sequence_parallel": False, @@ -128,7 +130,7 @@ } basic_lora_test_config: LoRAConfig = { - "enabled": True, + "enabled": False, "target_modules": [], "exclude_modules": [], "match_all_linear": True, @@ -137,7 +139,7 @@ "dropout": 0.0, "dropout_position": "post", 
"lora_A_init": "xavier", - "use_triton": True, + "use_triton": False, } @@ -701,7 +703,13 @@ def configure_worker_fixed_seed(num_gpus, bundle_indices=None): async def run_hf_train_process( - lm_policy, vllm_policy, tokenizer, async_engine, colocated, vllm_precision + lm_policy, + vllm_policy, + tokenizer, + async_engine, + colocated, + vllm_precision, + enable_lora, ): """Validates that the two policies can work together. @@ -742,7 +750,13 @@ async def run_hf_train_process( ) print("refitting vllm policy...") - refit_policy_generation(lm_policy, vllm_policy, colocated) + refit_policy_generation( + lm_policy, + vllm_policy, + colocated, + refit_base_model_weights=True, + refit_lora_weights=enable_lora, + ) # Step 1: Use vLLM for generation print("Using vLLM policy for fast generation...") @@ -881,16 +895,17 @@ async def run_hf_train_process( @pytest.mark.timeout(300) @pytest.mark.asyncio @pytest.mark.parametrize( - ("async_engine", "cpu_offload", "vllm_precision"), + ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"), [ - (True, False, "bfloat16"), - (False, True, "bfloat16"), - (True, False, "fp8"), - (False, True, "fp8"), + (True, False, "bfloat16", False), + (False, True, "bfloat16", False), + (True, False, "bfloat16", True), + (True, False, "fp8", False), + (False, True, "fp8", False), ], ) async def test_vllm_generation_with_hf_training_colocated( - cluster, tokenizer, async_engine, cpu_offload, vllm_precision + cluster, tokenizer, async_engine, cpu_offload, vllm_precision, enable_lora ): """This test validates that DTensor policy can work together with colocated vLLM policy.""" @@ -907,6 +922,8 @@ async def test_vllm_generation_with_hf_training_colocated( vllm_config = deepcopy(basic_vllm_test_config) vllm_config["vllm_cfg"]["async_engine"] = async_engine vllm_config["vllm_cfg"]["precision"] = vllm_precision + vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = enable_lora vllm_config = configure_generation_config(vllm_config, tokenizer) vllm_policy = VllmGeneration(cluster, vllm_config) @@ -916,6 +933,9 @@ async def test_vllm_generation_with_hf_training_colocated( print("Creating DTensor policy...") dtensor_config = deepcopy(basic_dtensor_test_config) dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload + dtensor_config["dtensor_cfg"]["_v2"] = enable_lora + dtensor_config["dtensor_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + dtensor_config["dtensor_cfg"]["lora_cfg"]["enabled"] = enable_lora dtensor_config["train_global_batch_size"] = 4 lm_policy = Policy(cluster, dtensor_config, tokenizer) @@ -926,7 +946,13 @@ async def test_vllm_generation_with_hf_training_colocated( # Test await run_hf_train_process( - lm_policy, vllm_policy, tokenizer, async_engine, True, vllm_precision + lm_policy, + vllm_policy, + tokenizer, + async_engine, + True, + vllm_precision, + enable_lora, ) @@ -935,16 +961,20 @@ async def test_vllm_generation_with_hf_training_colocated( @pytest.mark.parametrize( ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"), [ - (True, False, "bfloat16", True), - (False, True, "bfloat16", True), (True, False, "bfloat16", False), (False, True, "bfloat16", False), + (True, False, "bfloat16", True), (True, False, "fp8", False), (False, True, "fp8", False), ], ) async def test_vllm_generation_with_hf_training_non_colocated( - policy_cluster_separate, tokenizer, async_engine, cpu_offload, vllm_precision + policy_cluster_separate, + tokenizer, + async_engine, + cpu_offload, + 
vllm_precision, + enable_lora, ): # Skip the fp8 tests if the GPU is not H100 or newer (compute capability < 9.0) if vllm_precision == "fp8": @@ -960,19 +990,30 @@ async def test_vllm_generation_with_hf_training_non_colocated( # Create VllmGeneration Policy print("Creating vLLM policy...") vllm_config = deepcopy(basic_vllm_test_config) + vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) vllm_config["vllm_cfg"]["async_engine"] = async_engine vllm_config["vllm_cfg"]["precision"] = vllm_precision + vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = enable_lora vllm_config["colocated"]["enabled"] = False + if vllm_precision == "fp8": + vllm_config["vllm_cfg"]["kv_cache_dtype"] = "fp8" vllm_config = configure_generation_config(vllm_config, tokenizer) vllm_policy = VllmGeneration(generation_cluster_separate, vllm_config) vllm_policy.finish_generation() + assert not (enable_lora and vllm_precision == "fp8"), ( + "LoRA is not supported with FP8" + ) # Create Policy print("Creating DTensor policy...") dtensor_config = deepcopy(basic_dtensor_test_config) dtensor_config["generation"]["colocated"]["enabled"] = False dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload dtensor_config["train_global_batch_size"] = 4 + # lora must use dtensor v2 + dtensor_config["dtensor_cfg"]["_v2"] = enable_lora + dtensor_config["dtensor_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + dtensor_config["dtensor_cfg"]["lora_cfg"]["enabled"] = enable_lora lm_policy = Policy(policy_cluster_separate, dtensor_config, tokenizer) # Refit @@ -995,7 +1036,13 @@ async def test_vllm_generation_with_hf_training_non_colocated( # Test await run_hf_train_process( - lm_policy, vllm_policy, tokenizer, async_engine, False, vllm_precision + lm_policy, + vllm_policy, + tokenizer, + async_engine, + False, + vllm_precision, + enable_lora, ) @@ -2533,127 +2580,3 @@ def test_vllm_megatron_weight_update_with_packing(cluster, test_input_data): megatron_policy.shutdown() if vllm_generation: vllm_generation.shutdown() - - -from nemo_rl.utils.logger import BLUE, print_colored - - -def test_vllm_lora_refit_sync_colocated(cluster, tokenizer): - """Test vLLM LoRA refit with sync engine and colocated setup.""" - vllm_config = deepcopy(basic_vllm_test_config) - vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) - vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = True - vllm_config["vllm_cfg"]["async_engine"] = False - vllm_config = configure_generation_config(vllm_config, tokenizer) - - dtensor_config = deepcopy(basic_dtensor_test_config) - dtensor_config["dtensor_cfg"]["_v2"] = True - dtensor_config["dtensor_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) - dtensor_config["dtensor_cfg"]["lora_cfg"]["enabled"] = True - - print_colored("CREATING DTENSOR POLICY", BLUE) - lm_policy = Policy(cluster, dtensor_config, tokenizer) - - print_colored("CREATING VLLM POLICY", BLUE) - vllm_policy = VllmGeneration(cluster, vllm_config) - vllm_policy.finish_generation() - - print_colored("PREPARING REFIT INFO", BLUE) - state_dict_info = lm_policy.prepare_refit_info() - vllm_policy.prepare_refit_info(state_dict_info) - # take it outside statistics to get clean peak memory during refit - lm_policy.offload_before_refit() - - print_colored("STARTING VLLM POLICY REFIT BASE MODEL WEIGHTS", BLUE) - refit_policy_generation( - lm_policy, - vllm_policy, - vllm_config["colocated"]["enabled"], - _refit_buffer_size_gb=3, - refit_base_model_weights=True, - refit_lora_weights=vllm_config["vllm_cfg"]["lora_cfg"]["enabled"], - ) - - 
# vllm_model_state_dict = vllm_policy.get_model_state_dict()[0][0] - # print_colored(f"VLLM MODEL STATE DICT: {vllm_model_state_dict.keys()}", BLUE) - # from transformers import AutoModel - - # model = AutoModel.from_pretrained(model_name) - # model_state_dict = model.state_dict() - - # for name, vllm_tensor in vllm_model_state_dict.items(): - # name = name.replace("model.", "") - # model_tensor = model_state_dict[name] - # print_colored( - # f"NAME: {name}, vllm type : {vllm_tensor.dtype}, model type: {model_tensor.dtype}", - # BLUE, - # ) - # vllm_tensor = vllm_tensor.to("cpu") - # model_tensor = model_tensor.to(vllm_tensor.dtype).to("cpu") - # if not torch.allclose(vllm_tensor, model_tensor): - # print_colored(f"Tensor {name} is not close", RED) - # print_colored(f"MODEL TENSOR: {model_tensor.shape}", RED) - # print_colored(f"VLLM TENSOR: {vllm_tensor.shape}", RED) - # assert False, f"Tensor {name} is not close" - - print_colored("GENERATING TEXT", BLUE) - prompts = [ - "What is the largest number, all of whose digits are 1 or 4, and whose digits add up to 12?" - ] - test_tokenizer = get_tokenizer({"name": model_name}) - tokenized = test_tokenizer( - prompts, - padding=True, - truncation=True, - max_length=256, - return_tensors="pt", - padding_side="right", - ) - test_input_data = BatchedDataDict( - { - "input_ids": tokenized["input_ids"], - "input_lengths": tokenized["attention_mask"].sum(dim=1).to(torch.int32), - } - ) - vllm_policy.prepare_for_generation() - outputs = vllm_policy.generate(test_input_data, greedy=True) - output_ids = outputs["output_ids"] - generated_texts = test_tokenizer.batch_decode(output_ids, skip_special_tokens=True) - print_colored(f"GENERATED TEXT: {generated_texts}") - - -def test_vllm_lora_generation(cluster, tokenizer): - """Test vLLM LoRA refit with sync engine and colocated setup.""" - vllm_config = deepcopy(basic_vllm_test_config) - vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) - vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = True - vllm_config["vllm_cfg"]["async_engine"] = False - vllm_config = configure_generation_config(vllm_config, tokenizer) - - print_colored("CREATING VLLM POLICY") - vllm_policy = VllmGeneration(cluster, vllm_config) - vllm_policy.prepare_for_generation() - - # print_colored("GENERATING TEXT") - # prompts = [ - # "What is the largest number, all of whose digits are 1 or 4, and whose digits add up to 12?" 
- # ] - # test_tokenizer = get_tokenizer({"name": model_name}) - # tokenized = test_tokenizer( - # prompts, - # padding=True, - # truncation=True, - # max_length=256, - # return_tensors="pt", - # padding_side="right", - # ) - # test_input_data = BatchedDataDict( - # { - # "input_ids": tokenized["input_ids"], - # "input_lengths": tokenized["attention_mask"].sum(dim=1).to(torch.int32), - # } - # ) - # outputs = vllm_policy.generate(test_input_data, greedy=True) - # output_ids = outputs["output_ids"] - # generated_texts = test_tokenizer.batch_decode(output_ids, skip_special_tokens=True) - # print_colored(f"GENERATED TEXT: {generated_texts}") From 0de61c73b4f8471fa71fbdd7b469c1ed8479ae89 Mon Sep 17 00:00:00 2001 From: ruit Date: Sun, 4 Jan 2026 21:52:57 -0800 Subject: [PATCH 6/6] update default value Signed-off-by: ruit --- nemo_rl/models/generation/vllm/vllm_backend.py | 4 ++-- nemo_rl/models/policy/lm_policy.py | 2 +- nemo_rl/models/policy/workers/dtensor_policy_worker.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index efa3443595..e3f5d14008 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -250,8 +250,8 @@ def _apply_loaded_weights( def update_weights_via_ipc_zmq( self, lora_config: dict[str, Any] = {}, - refit_base_model_weights: bool = False, - refit_lora_weights: bool = True, + refit_base_model_weights: bool = True, + refit_lora_weights: bool = False, ) -> bool: """Receive and update model weights via ZMQ IPC socket. diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 948bad5192..0681afc599 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -767,7 +767,7 @@ def stream_weights_via_ipc_zmq( buffer_size_bytes: int, kv_scales: Optional[dict[str, float]] = None, refit_base_model_weights: bool = True, - refit_lora_weights: bool = True, + refit_lora_weights: bool = False, ) -> list[ray.ObjectRef]: """Send the weights for IPC handles via ZMQ socket.""" futures = self.worker_group.run_all_workers_single_data( diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 7be97e21d1..cc4b6d61df 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -1685,7 +1685,7 @@ def stream_weights_via_ipc_zmq( ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" assert refit_base_model_weights and refit_lora_weights == False, ( - "dtensor v1 not support lora. refit_lora_weights must be False" + f"dtensor v1 not support lora. refit_lora_weights must be False, but got refit_lora_weights={refit_lora_weights} and refit_base_model_weights={refit_base_model_weights}" ) if kv_scales is not None: raise NotImplementedError(
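One readability note on the v1 guard touched above: "a and b == False" parses as "a and (b == False)", so the assertion requires base-model refit to be requested and LoRA refit to be off. An equivalent, more explicit spelling (a sketch, not the repo's code) would be:

# Explicit form of the dtensor v1 refit guard.
def check_v1_refit_flags(refit_base_model_weights: bool, refit_lora_weights: bool) -> None:
    assert refit_base_model_weights and not refit_lora_weights, (
        "dtensor v1 does not support LoRA: refit_base_model_weights must be True and "
        f"refit_lora_weights must be False (got base={refit_base_model_weights}, "
        f"lora={refit_lora_weights})"
    )

check_v1_refit_flags(True, False)   # passes; any other combination raises AssertionError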