diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 1bc7df9810..7cacd7ae97 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -178,17 +178,88 @@ def make_eagle_supervised_data_module( class EagleTrainerWithAccLog(Trainer): """Wrapper around Trainer that logs training accuracy.""" + def __init__( + self, + *args, + lora_lr_multiplier: float = 1.0, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.lora_lr_multiplier = lora_lr_multiplier + + def create_optimizer(self): + """Override to give LoRA parameters a higher learning rate.""" + super().create_optimizer() + if self.lora_lr_multiplier != 1.0: + lora_ids = { + id(p) for n, p in self.model.named_parameters() if "lora_" in n and p.requires_grad + } + if lora_ids: + new_groups = [] + for group in self.optimizer.param_groups: + lora = [p for p in group["params"] if id(p) in lora_ids] + others = [p for p in group["params"] if id(p) not in lora_ids] + if lora and others: + new_groups.append({**group, "params": others}) + new_groups.append( + {**group, "params": lora, "lr": group["lr"] * self.lora_lr_multiplier} + ) + elif lora: + new_groups.append({**group, "lr": group["lr"] * self.lora_lr_multiplier}) + else: + new_groups.append(group) + self.optimizer.param_groups = new_groups + return self.optimizer + def compute_loss(self, *args, **kwargs): - """Override compute_loss to save train accs in trainer state.""" + """Override compute_loss to save train accs and per-component losses in trainer state.""" if not hasattr(self.state, "training_accs"): self.state.training_accs = [] + if not hasattr(self.state, "component_losses"): + self.state.component_losses = {"eagle": [], "preservation": []} kwargs.pop("num_items_in_batch", None) loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs) - if hasattr(outputs, "train_acc"): + if hasattr(outputs, "train_acc") and any(outputs.train_acc): self.state.training_accs.append(outputs.train_acc) + # Track per-component losses + for key, attr in [ + ("eagle", "eagle_loss"), + ("preservation", "preservation_loss"), + ]: + val = getattr(outputs, attr, None) + if val is not None: + self.state.component_losses[key].append(val.item()) return loss +class LoRAWarmupCallback(TrainerCallback): + """Manages LoRA warmup: freezes LoRA during warmup, unfreezes after.""" + + def __init__(self, warmup_steps: int): + self.warmup_steps = warmup_steps + self._activated = False + + def on_step_begin(self, args, state, control, **kwargs): + """Check if warmup is over and activate LoRA co-training.""" + if self._activated: + return control + if state.global_step >= self.warmup_steps: + model = kwargs["model"] + # Unwrap DDP/FSDP if needed + raw_model = model.module if hasattr(model, "module") else model + if hasattr(raw_model, "_lora_cotraining_active"): + raw_model._lora_cotraining_active = True + # Unfreeze LoRA parameters + for name, param in raw_model._base_model.named_parameters(): + if "lora_" in name: + param.requires_grad = True + print_rank_0( + f"Step {state.global_step}: LoRA warmup complete, enabling co-training." + ) + self._activated = True + return control + + class EagleTrainingPlot(TrainerCallback): """Callback that plot training acc and AR during training.""" @@ -230,8 +301,16 @@ def on_log(self, args, state, control, **kwargs): if self.estimate_ar: wandb.log({"estimated_training_ar": est_ar}, step=state.global_step) - # reset training_accs + # Log per-component losses + if hasattr(state, "component_losses"): + for key, vals in state.component_losses.items(): + if vals: + wandb.log({f"{key}_loss": np.mean(vals)}, step=state.global_step) + + # reset training_accs and component_losses state.training_accs = [] + if hasattr(state, "component_losses"): + state.component_losses = {"eagle": [], "preservation": []} return control def on_step_end(self, args, state, control, **kwargs): @@ -240,6 +319,7 @@ def on_step_end(self, args, state, control, **kwargs): return control if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: print_rank_0("Running AR validation...") + torch.cuda.empty_cache() try: ars = validate_ar( model=kwargs["model"], diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c15b97bdaa..65af978f68 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -134,6 +134,38 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi FSDP="${1#*=}" ;; + --eagle_base_lora_rank*) + if [[ "$1" != *=* ]]; then shift; fi + EAGLE_BASE_LORA_RANK="${1#*=}" + ;; + --eagle_base_lora_alpha*) + if [[ "$1" != *=* ]]; then shift; fi + EAGLE_BASE_LORA_ALPHA="${1#*=}" + ;; + --eagle_base_lora_target_modules*) + if [[ "$1" != *=* ]]; then shift; fi + EAGLE_BASE_LORA_TARGET_MODULES="${1#*=}" + ;; + --eagle_base_lora_preservation_loss_weight*) + if [[ "$1" != *=* ]]; then shift; fi + EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT="${1#*=}" + ;; + --eagle_base_lora_lr_multiplier*) + if [[ "$1" != *=* ]]; then shift; fi + EAGLE_BASE_LORA_LR_MULTIPLIER="${1#*=}" + ;; + --eagle_base_lora_warmup_steps*) + if [[ "$1" != *=* ]]; then shift; fi + EAGLE_BASE_LORA_WARMUP_STEPS="${1#*=}" + ;; + --eagle_base_lora_logits_detach_prob*) + if [[ "$1" != *=* ]]; then shift; fi + EAGLE_BASE_LORA_LOGITS_DETACH_PROB="${1#*=}" + ;; + --eagle_base_lora*) + if [[ "$1" != *=* ]]; then shift; fi + EAGLE_BASE_LORA="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -184,6 +216,14 @@ DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"} NUM_TTT_STEPS=${NUM_TTT_STEPS:-3} +EAGLE_BASE_LORA=${EAGLE_BASE_LORA:-"False"} +EAGLE_BASE_LORA_RANK=${EAGLE_BASE_LORA_RANK:-64} +EAGLE_BASE_LORA_ALPHA=${EAGLE_BASE_LORA_ALPHA:-16.0} +EAGLE_BASE_LORA_TARGET_MODULES=${EAGLE_BASE_LORA_TARGET_MODULES:-""} +EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT=${EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT:-1.0} +EAGLE_BASE_LORA_LR_MULTIPLIER=${EAGLE_BASE_LORA_LR_MULTIPLIER:-1.0} +EAGLE_BASE_LORA_WARMUP_STEPS=${EAGLE_BASE_LORA_WARMUP_STEPS:-0} +EAGLE_BASE_LORA_LOGITS_DETACH_PROB=${EAGLE_BASE_LORA_LOGITS_DETACH_PROB:-0.5} USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"} TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"} @@ -218,6 +258,21 @@ else VLM_ARGS="" fi +if [[ "$EAGLE_BASE_LORA" == "True" ]]; then + LORA_ARGS="--eagle_base_lora True \ + --eagle_base_lora_rank $EAGLE_BASE_LORA_RANK \ + --eagle_base_lora_alpha $EAGLE_BASE_LORA_ALPHA \ + --eagle_base_lora_preservation_loss_weight $EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT \ + --eagle_base_lora_lr_multiplier $EAGLE_BASE_LORA_LR_MULTIPLIER \ + --eagle_base_lora_warmup_steps $EAGLE_BASE_LORA_WARMUP_STEPS \ + --eagle_base_lora_logits_detach_prob $EAGLE_BASE_LORA_LOGITS_DETACH_PROB" + if [[ "$EAGLE_BASE_LORA_TARGET_MODULES" != "" ]]; then + LORA_ARGS="$LORA_ARGS --eagle_base_lora_target_modules $EAGLE_BASE_LORA_TARGET_MODULES" + fi +else + LORA_ARGS="" +fi + if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then #Use FSDP2 when multi GPU available FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" @@ -283,6 +338,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --cp_size $CP_SIZE \ --dp_shard_size $DP_SHARD_SIZE \ --num_ttt_steps $NUM_TTT_STEPS \ + $LORA_ARGS \ " start_time=$(date +%s) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 3369d399c2..11793916de 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -40,6 +40,7 @@ from eagle_utils import ( EagleTrainerWithAccLog, EagleTrainingPlot, + LoRAWarmupCallback, make_eagle_supervised_data_module, patch_ring_attention_for_ttt, ) @@ -142,6 +143,68 @@ class EagleArguments: default=3, metadata={"help": "Number of train-time-test steps to use during training."}, ) + eagle_base_lora: bool = field( + default=False, + metadata={ + "help": ( + "Whether to add LoRA adapters to the base model for co-training with the EAGLE " + "draft module. Requires the `peft` library. Incompatible with offline training." + ) + }, + ) + eagle_base_lora_rank: int = field( + default=64, + metadata={"help": "LoRA rank for the base model adapters."}, + ) + eagle_base_lora_alpha: float = field( + default=16.0, + metadata={"help": "LoRA alpha (scaling) for the base model adapters."}, + ) + eagle_base_lora_target_modules: str = field( + default=None, + metadata={ + "help": ( + "Comma-separated list of module name patterns to apply LoRA to in the base model " + "(e.g. 'q_proj,v_proj'). Defaults to peft's default target modules." + ) + }, + ) + eagle_base_lora_preservation_loss_weight: float = field( + default=1.0, + metadata={ + "help": ( + "Weight for the preservation loss that minimizes KL divergence between the " + "LoRA-adapted base model output and the original base model output." + ) + }, + ) + eagle_base_lora_lr_multiplier: float = field( + default=1.0, + metadata={ + "help": ( + "Learning rate multiplier for LoRA parameters relative to the base learning rate." + ) + }, + ) + eagle_base_lora_warmup_steps: int = field( + default=0, + metadata={ + "help": ( + "Number of warmup steps where LoRA is frozen and only the EAGLE draft head trains. " + "After warmup, LoRA is enabled for co-training." + ) + }, + ) + eagle_base_lora_logits_detach_prob: float = field( + default=0.5, + metadata={ + "help": ( + "After warmup, probability of detaching base logits each step. Acts as dropout " + "regularization on the eagle-loss-to-LoRA gradient path, preventing LoRA from " + "degenerating. 1.0 = always detach, 0.0 = never detach." + ) + }, + ) def train(): @@ -217,6 +280,11 @@ def train(): json.load(open(eagle_args.eagle_config)) if eagle_args.eagle_config else {} ) + lora_target_modules = ( + eagle_args.eagle_base_lora_target_modules.split(",") + if eagle_args.eagle_base_lora_target_modules + else None + ) config = { "eagle_decoder_type": eagle_args.eagle_decoder_type, "eagle_offline": use_offline_training, @@ -224,6 +292,13 @@ def train(): "eagle_use_torch_compile": not eagle_args.disable_torch_compile, "eagle_ttt_steps": eagle_args.num_ttt_steps, "eagle_architecture_config": custom_config, + "eagle_base_lora": eagle_args.eagle_base_lora, + "eagle_base_lora_rank": eagle_args.eagle_base_lora_rank, + "eagle_base_lora_alpha": eagle_args.eagle_base_lora_alpha, + "eagle_base_lora_target_modules": lora_target_modules, + "eagle_base_lora_preservation_loss_weight": eagle_args.eagle_base_lora_preservation_loss_weight, + "eagle_base_lora_warmup_steps": eagle_args.eagle_base_lora_warmup_steps, + "eagle_base_lora_logits_detach_prob": eagle_args.eagle_base_lora_logits_detach_prob, } mtsp.convert(model, [("eagle", config)]) @@ -239,17 +314,36 @@ def train(): else: raise Exception(f"{training_args.mode} is not supported!") + # Move any remaining CPU buffers to CUDA so DDP (NCCL-only) can broadcast + # them. We iterate named_buffers and reassign via the owning module to + # keep the module tree consistent. Parameters are left on CPU — the HF + # Trainer will move them during init. + if torch.cuda.is_available(): + _target_dev = torch.device("cuda", 0) + for name, buf in list(model.named_buffers()): + if buf.device.type == "cpu": + parts = name.split(".") + mod = model + for p in parts[:-1]: + mod = getattr(mod, p) + setattr(mod, parts[-1], buf.to(_target_dev)) + print_rank_0("Loading dataset...") if training_args.mode == "eagle3": data_module = make_eagle_supervised_data_module( tokenizer, data_args, train_len=training_args.training_seq_len ) + callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)] + if eagle_args.eagle_base_lora and eagle_args.eagle_base_lora_warmup_steps > 0: + callbacks.append(LoRAWarmupCallback(eagle_args.eagle_base_lora_warmup_steps)) + trainer = EagleTrainerWithAccLog( model=model, processing_class=tokenizer, args=training_args, - callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)], + callbacks=callbacks, + lora_lr_multiplier=eagle_args.eagle_base_lora_lr_multiplier, **data_module, ) diff --git a/examples/speculative_decoding/requirements.txt b/examples/speculative_decoding/requirements.txt index 6324bac62b..e7b83ab2dd 100644 --- a/examples/speculative_decoding/requirements.txt +++ b/examples/speculative_decoding/requirements.txt @@ -1,2 +1,3 @@ accelerate==1.12.0 +peft==0.18.1 transformers==5.0.0rc1 diff --git a/examples/speculative_decoding/scripts/merge_lora.py b/examples/speculative_decoding/scripts/merge_lora.py new file mode 100644 index 0000000000..8db701690b --- /dev/null +++ b/examples/speculative_decoding/scripts/merge_lora.py @@ -0,0 +1,169 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Merge LoRA weights from an exported EAGLE checkpoint into the base model and save. + +Usage: + python merge_lora.py \ + --base_model_path /path/to/original/base/model \ + --exported_lora_dir /path/to/exported/eagle/checkpoint \ + --output_path /path/to/merged/output + +The exported checkpoint (from export_hf_checkpoint.py) contains +lora_adapter_model.safetensors and lora_adapter_config.json. This script +loads the original base model, applies the trained LoRA adapters, merges +them into the base weights, and saves the fused model + tokenizer. +""" + +import argparse +import json +from pathlib import Path + +from safetensors.torch import load_file +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Merge LoRA weights from an exported EAGLE checkpoint into the base model." + ) + parser.add_argument( + "--base_model_path", + type=str, + required=True, + help="Path to the original base model (HF model name or local path).", + ) + parser.add_argument( + "--exported_lora_dir", + type=str, + required=True, + help="Path to the exported EAGLE checkpoint containing lora_adapter_model.safetensors.", + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Directory to save the merged (fused) base model.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + lora_dir = Path(args.exported_lora_dir) + + # Verify exported files exist + config_path = lora_dir / "lora_adapter_config.json" + weights_path = lora_dir / "lora_adapter_model.safetensors" + if not config_path.exists() or not weights_path.exists(): + raise FileNotFoundError( + f"Expected lora_adapter_config.json and lora_adapter_model.safetensors " + f"in {lora_dir}. Run export_hf_checkpoint.py first." + ) + + with open(config_path) as f: + lora_config_dict = json.load(f) + lora_sd = load_file(weights_path) + print(f"Loaded {len(lora_sd)} LoRA tensors from {lora_dir}") + print(f" Sample keys: {list(lora_sd.keys())[:4]}") + + # Load the original base model + print(f"Loading base model from {args.base_model_path}...") + model = AutoModelForCausalLM.from_pretrained( + args.base_model_path, torch_dtype="auto", device_map="cpu", trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained(args.base_model_path, trust_remote_code=True) + + # Create PeftModel by injecting LoRA layers from config + print("Injecting LoRA layers...") + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=lora_config_dict["r"], + lora_alpha=lora_config_dict["lora_alpha"], + target_modules=lora_config_dict["target_modules"], + bias=lora_config_dict.get("bias", "none"), + ) + model = get_peft_model(model, lora_config) + + # Build key mapping: exported keys -> PeftModel state dict keys + # PeftModel keys look like: base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight + # Exported keys look like: model.layers.0.self_attn.q_proj.lora_A.default.weight + peft_lora_keys = {k for k in model.state_dict() if ".lora_A." in k or ".lora_B." in k} + print(f" PeftModel has {len(peft_lora_keys)} LoRA parameters") + print(f" Sample PeftModel keys: {sorted(peft_lora_keys)[:4]}") + + # Determine prefix by matching the first exported key against PeftModel keys + sample_export_key = next(iter(lora_sd)) + matching_peft_keys = [k for k in peft_lora_keys if k.endswith(sample_export_key)] + if matching_peft_keys: + prefix = matching_peft_keys[0][: -len(sample_export_key)] + print(f" Detected key prefix: '{prefix}'") + else: + # Try without .default. segment (in case export format differs) + prefix = "" + print(" WARNING: Could not auto-detect prefix, trying direct key match") + + # Load weights into PeftModel + print("Loading LoRA weights into PeftModel...") + peft_sd = model.state_dict() + loaded_count = 0 + missing_keys = [] + for export_key, tensor in lora_sd.items(): + peft_key = prefix + export_key + if peft_key in peft_sd: + peft_sd[peft_key] = tensor + loaded_count += 1 + else: + missing_keys.append(export_key) + + if missing_keys: + print(f" WARNING: {len(missing_keys)} exported keys not found in PeftModel:") + for k in missing_keys[:10]: + print(f" {k}") + if len(missing_keys) > 10: + print(f" ... and {len(missing_keys) - 10} more") + + if loaded_count == 0: + raise RuntimeError( + "No exported LoRA keys matched PeftModel keys. " + "Check export format vs PeftModel key naming." + ) + + model.load_state_dict(peft_sd) + print(f" Successfully loaded {loaded_count}/{len(lora_sd)} LoRA tensors") + + # Verify lora_B weights are non-zero (B is init'd to zero, so non-zero means loaded) + lora_b_norms = [v.norm().item() for k, v in model.state_dict().items() if ".lora_B." in k] + if not lora_b_norms or all(n == 0 for n in lora_b_norms): + raise RuntimeError("LoRA-B weights are all zero — adapter loading failed.") + print( + f" Verified: {len(lora_b_norms)} LoRA-B matrices " + f"(mean norm={sum(lora_b_norms) / len(lora_b_norms):.4f})" + ) + + # Merge LoRA into base weights and remove adapter wrappers + model = model.merge_and_unload() + print("LoRA merged successfully.") + + # Save + print(f"Saving merged model to {args.output_path}...") + model.save_pretrained(args.output_path) + tokenizer.save_pretrained(args.output_path) + print(f"Done! Merged model saved to {args.output_path}") + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index aca19a1580..d8cb3f862c 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -112,11 +112,17 @@ def _check_valid_sd(self, export_sd: dict): "llama": LLAMA_EAGLE_SINGLE_LAYER, "kimik2": KIMIK2_EAGLE_SINGLE_LAYER, }[self.eagle_decoder_type] + # fc and hidden_norm are only present when use_aux_hidden_state=True + use_aux = getattr(self.model.eagle_config, "use_aux_hidden_state", False) + aux_only_keys = {"fc", "layers.0.hidden_norm"} + required_keys = set(expected_keys_single_layer["required"]) + if not use_aux: + required_keys -= aux_only_keys # Check that export sd has required keys - for key in expected_keys_single_layer["required"]: + for key in required_keys: assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight" for i in range(1, self.num_hidden_layers): - for key in expected_keys_single_layer["required"] - { + for key in required_keys - { "layers.0.hidden_norm", "layers.0.input_layernorm", "norm", @@ -185,6 +191,39 @@ def _get_config_from_draft_or_base(key: str, model: nn.Module): return template_config + def _export_lora(self, export_dir: Path, full_sd: dict): + """Export base model LoRA adapter weights alongside the eagle module artifacts.""" + from peft import LoraConfig + + lora_sd = {k: v for k, v in full_sd.items() if ".lora_A." in k or ".lora_B." in k} + if not lora_sd: + raise RuntimeError( + "No LoRA adapter tensors found in the model state dict. " + "Ensure eagle_base_lora=True and the model was converted with LoRA adapters." + ) + # Rename keys to PeftModel format: lora_A.weight -> lora_A.default.weight + lora_sd = { + re.sub(r"(lora_[AB])\.weight$", r"\1.default.weight", k): v for k, v in lora_sd.items() + } + save_file(lora_sd, export_dir / "lora_adapter_model.safetensors") + + # Infer target modules from the exported LoRA keys (e.g., "q_proj", "v_proj") + # Keys are like: model.layers.0.self_attn.q_proj.lora_A.default.weight + target_modules = sorted({k.split(".")[-4] for k in lora_sd if ".lora_A." in k}) + lora_config = LoraConfig( + r=self.model.eagle_base_lora_rank, + lora_alpha=self.model.eagle_base_lora_alpha, + target_modules=target_modules, + bias="none", + ) + with open(export_dir / "lora_adapter_config.json", "w") as f: + json.dump( + lora_config.to_dict(), + f, + indent=4, + default=lambda o: sorted(o) if isinstance(o, set) else o, + ) + def export(self, export_dir: Path | str, dtype: torch.dtype | None = None): """Export the model to the deployment format.""" # Make export dir @@ -215,6 +254,10 @@ def export(self, export_dir: Path | str, dtype: torch.dtype | None = None): with open(f"{export_dir}/hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) + # Export LoRA adapter weights separately + if getattr(self.model, "eagle_base_lora", False): + self._export_lora(export_dir, full_sd) + class EagleMedusaExporter(EagleExporter): """Draft model exporter for EagleMedusa.""" diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 69491c6599..b03164e962 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -120,3 +120,55 @@ class EagleConfig(ModeloptBaseConfig): default=False, description="Whether to enable NVTX ranges for profiling eagle forward/loss methods.", ) + + eagle_base_lora: bool = ModeloptField( + default=False, + description=( + "Whether to add LoRA adapters to the base model for co-training with the EAGLE module. " + "Requires the `peft` library. Incompatible with eagle_offline=True." + ), + ) + + eagle_base_lora_rank: int = ModeloptField( + default=64, + description="LoRA rank for the base model adapters.", + ) + + eagle_base_lora_alpha: float = ModeloptField( + default=16.0, + description="LoRA alpha (scaling) for the base model adapters.", + ) + + eagle_base_lora_target_modules: list | None = ModeloptField( + default=None, + description=( + "List of module name patterns to apply LoRA to in the base model " + "(e.g. ['q_proj', 'v_proj']). None uses peft defaults." + ), + ) + + eagle_base_lora_preservation_loss_weight: float = ModeloptField( + default=0.1, + description=( + "Weight for the preservation loss that minimizes the KL divergence between " + "the LoRA-adapted base model output and the original base model output." + ), + ) + + eagle_base_lora_warmup_steps: int = ModeloptField( + default=0, + description=( + "Number of warmup steps where LoRA is frozen and only the EAGLE draft head trains. " + "After warmup, LoRA is enabled for co-training." + ), + ) + + eagle_base_lora_logits_detach_prob: float = ModeloptField( + default=0.5, + description=( + "After warmup, probability of detaching base_output_softmax_logits each step. " + "Acts as dropout regularization on the eagle-loss-to-LoRA gradient path through " + "logits, preventing LoRA from degenerating to maximize EAGLE accuracy at the cost " + "of base model quality. 1.0 = always detach (no logits gradient), 0.0 = never detach." + ), + ) diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index e2a08c5252..cc5b0f4035 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -41,3 +41,12 @@ def modify( self.eagle_mix_hidden_states = config.eagle_mix_hidden_states self.eagle_use_torch_compile = config.eagle_use_torch_compile self.eagle_enable_nvtx = config.eagle_enable_nvtx + self.eagle_base_lora = config.eagle_base_lora + self.eagle_base_lora_rank = config.eagle_base_lora_rank + self.eagle_base_lora_alpha = config.eagle_base_lora_alpha + self.eagle_base_lora_target_modules = config.eagle_base_lora_target_modules + self.eagle_base_lora_preservation_loss_weight = ( + config.eagle_base_lora_preservation_loss_weight + ) + self.eagle_base_lora_warmup_steps = config.eagle_base_lora_warmup_steps + self.eagle_base_lora_logits_detach_prob = config.eagle_base_lora_logits_detach_prob diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 8561a390fc..e3dcf568ec 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -512,12 +512,13 @@ def _set_default_aux_hidden_state_layers(self): def _collect_aux_hidden_states_forward_hook(self, module, input, output) -> None: """Collect auxiliary hidden states from base model intermediate layers, save them in attribute.""" - hidden_states = ( - output.clone().detach() - if isinstance(output, torch.Tensor) - else output[0].clone().detach() - ) - self._aux_hidden_states.append(hidden_states) + raw = output if isinstance(output, torch.Tensor) else output[0] + # With LoRA co-training (after warmup), keep grad so EAGLE loss + # backpropagates through hidden states to LoRA. + if self.training and getattr(self, "_lora_cotraining_active", False): + self._aux_hidden_states.append(raw.clone()) + else: + self._aux_hidden_states.append(raw.clone().detach()) def pop_and_gather_aux_hiddens(self): """Pop auxiliary hidden states from base model and gather them on the draft model device.""" @@ -548,6 +549,43 @@ def _get_eagle_device(self): base_model_last_layer = self._base_model.layers[-1] return next(base_model_last_layer.parameters()).device + def _inject_base_lora(self): + """Inject HF PEFT LoRA adapters into the base model in-place and unfreeze them.""" + from peft import LoraConfig + from peft.mapping import inject_adapter_in_model + + target_modules = self.eagle_base_lora_target_modules or None + lora_config = LoraConfig( + r=self.eagle_base_lora_rank, + lora_alpha=self.eagle_base_lora_alpha, + target_modules=target_modules, + bias="none", + ) + inject_adapter_in_model(lora_config, self._base_model, adapter_name="default") + # Unfreeze LoRA parameters unless we have a warmup phase + freeze_lora = self.eagle_base_lora_warmup_steps > 0 + for name, param in self._base_model.named_parameters(): + if "lora_" in name: + param.requires_grad = not freeze_lora + + def _set_base_lora_enabled(self, enabled: bool) -> None: + """Enable or disable LoRA adapters in the base model.""" + from peft.tuners.lora import LoraLayer + + for module in self._base_model.modules(): + if isinstance(module, LoraLayer): + module.enable_adapters(enabled) + + def _preservation_loss( + self, ref_logits: torch.Tensor, lora_logits: torch.Tensor + ) -> torch.Tensor: + """KL divergence encouraging LoRA output to stay close to the original base model. + + KL(softmax(ref) || log_softmax(lora)) weighted by eagle_base_lora_preservation_loss_weight. + """ + loss = nn.Softmax(dim=-1)(ref_logits.detach()) * nn.LogSoftmax(dim=-1)(lora_logits) + return -loss.sum(dim=-1).mean() * self.eagle_base_lora_preservation_loss_weight + def modify( self, config, @@ -606,6 +644,16 @@ def modify( if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids: layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook) + # Inject HF PEFT LoRA adapters into the base model for co-training + if self.eagle_base_lora: + if self.eagle_offline: + raise ValueError("eagle_base_lora is incompatible with eagle_offline=True") + self._inject_base_lora() + # Whether LoRA co-training is active this step. Controlled by the + # trainer based on warmup schedule. When False, LoRA params are + # frozen and logits are always detached (eagle-only training). + self._lora_cotraining_active = self.eagle_base_lora_warmup_steps == 0 + # delete base model layers for offline training if self.eagle_offline: self._base_model._modules.pop("layers") @@ -735,7 +783,16 @@ def _prepare_eagle_inputs( if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) base_output_predict_tok = base_model_logits.argmax(dim=-1).detach() - base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach() + base_output_softmax_logits = torch.softmax(base_model_logits, dim=2) + # After LoRA warmup, stochastically detach logits — acts as dropout + # regularization on the eagle-loss-to-LoRA gradient path, preventing + # LoRA from degenerating to maximize EAGLE acc at cost of base quality. + # During warmup or when LoRA is off, always detach. + lora_active = getattr(self, "_lora_cotraining_active", False) and self.training + if lora_active and torch.rand(1).item() >= self.eagle_base_lora_logits_detach_prob: + pass # keep gradients flowing through logits to LoRA + else: + base_output_softmax_logits = base_output_softmax_logits.detach() return ( eagle_input_embeds, @@ -751,7 +808,9 @@ def _compute_ttt_attention_mask( ) -> BlockMask | torch.Tensor: """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl.""" msk_func = get_ttt_msk_func(seq_length, ttt_step) - dtypemin = torch.finfo(self._base_llm_config.dtype).min + dtypemin = torch.finfo( + getattr(self._base_llm_config, "dtype", None) or torch.get_default_dtype() + ).min q_len = seq_length kv_len = seq_length * (1 + ttt_step) if self.eagle_config._attn_implementation == "flex_attention": @@ -767,7 +826,10 @@ def _compute_ttt_attention_mask( torch.arange(kv_len).view(1, 1, 1, kv_len), ).to(self.device) tensor_mask = torch.full_like( - tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device + tensor_mask, + 0, + dtype=getattr(self._base_llm_config, "dtype", None) or torch.get_default_dtype(), + device=self.device, ).masked_fill(~tensor_mask, dtypemin) return tensor_mask @@ -782,32 +844,50 @@ def _base_model_forward( labels, **kwargs, ): - with torch.no_grad() if freeze_base_model else contextlib.nullcontext(): - outputs = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - output_hidden_states=True, - **kwargs, - ) - past_key_values = getattr(outputs, "past_key_values", None) - base_input_embeds = outputs.hidden_states[0] - base_model_hidden_states = outputs.hidden_states[-1] - base_model_logits = outputs.logits + def _run_forward(no_grad): + with torch.no_grad() if no_grad else contextlib.nullcontext(): + return super(HFEagleModel, self).forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_hidden_states=True, + **kwargs, + ) - # Optionally, compute base model loss when we want to tune the base model. + # With LoRA co-training, run a reference forward (LoRA disabled, no grad) + # to get the original base model logits for preservation loss, then run + # the main forward with LoRA enabled and gradients flowing. + # During warmup (_lora_cotraining_active=False), skip entirely. + lora_active = getattr(self, "_lora_cotraining_active", False) and self.training + ref_logits = None + if lora_active: + self._set_base_lora_enabled(False) + try: + ref_logits = _run_forward(no_grad=True).logits + finally: + if hasattr(self, "_aux_hidden_states"): + self._aux_hidden_states.clear() + self._set_base_lora_enabled(True) + + outputs = _run_forward(no_grad=freeze_base_model and not lora_active) + past_key_values = getattr(outputs, "past_key_values", None) + base_model_logits = outputs.logits + + if ref_logits is not None: + base_model_loss = self._preservation_loss(ref_logits, base_model_logits) + elif not freeze_base_model and labels is not None: + loss_fct = CrossEntropyLoss() + base_model_loss = loss_fct( + base_model_logits.view(-1, base_model_logits.shape[-1]), labels.view(-1) + ) + else: base_model_loss = None - if not freeze_base_model and labels is not None: # Base model loss - loss_fct = CrossEntropyLoss() - loss_logits = base_model_logits.view(-1, base_model_logits.shape[-1]) - labels = labels.view(-1) - base_model_loss = loss_fct(loss_logits, labels) return EagleBaseModelOutput( - input_embeds=base_input_embeds, + input_embeds=outputs.hidden_states[0], aux_hiddens=self.pop_and_gather_aux_hiddens(), - out_hiddens=base_model_hidden_states, + out_hiddens=outputs.hidden_states[-1], logits=base_model_logits, loss=base_model_loss, ), past_key_values @@ -997,7 +1077,7 @@ def forward( # Slice by actual number of steps taken, in case of early return train_accs = train_accs[:, : ttt_step + 1].tolist() - # Merge base model loss and eagle loss + # Merge eagle loss and preservation loss (if LoRA co-training) if base_outputs.loss is None and eagle_loss is None: loss = None assert not self.training, "At least one loss must be computed for training." @@ -1010,6 +1090,8 @@ def forward( past_key_values=past_key_values, hidden_states=base_outputs.out_hiddens, train_acc=train_accs, + eagle_loss=eagle_loss, + preservation_loss=base_outputs.loss if self.eagle_base_lora else None, ) def _eagle_loss( diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 9e167c8dc9..ba28dfb7c1 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -477,7 +477,7 @@ def enable_cp_ttt_patch(): import modelopt.torch.speculative.plugins.transformers modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True - with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]): try: yield finally: diff --git a/tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py b/tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py new file mode 100644 index 0000000000..3b9bf996a2 --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for EAGLE + LoRA co-training (eagle_base_lora feature).""" + +from copy import deepcopy + +import pytest +import torch +from _test_utils.torch.transformers_models import get_tiny_llama +from peft.tuners.lora import LoraLayer + +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.eagle.default_config import default_eagle_config + +TINY_EAGLE_CFG = { + "num_hidden_layers": 1, + "intermediate_size": 32, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 2, + "use_last_layernorm": True, + "use_aux_hidden_state": False, + "eagle_aux_hidden_state_layer_ids": [], +} + +EAGLE_LORA_CONFIG = { + "eagle_architecture_config": {**default_eagle_config, **TINY_EAGLE_CFG}, + "eagle_base_lora": True, + "eagle_base_lora_rank": 4, + "eagle_base_lora_alpha": 8.0, + "eagle_base_lora_target_modules": ["q_proj", "v_proj"], + "eagle_base_lora_preservation_loss_weight": 0.1, +} + + +@pytest.fixture +def lora_eagle_model(): + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, mode=[("eagle", deepcopy(EAGLE_LORA_CONFIG))]) + return model + + +def test_lora_layers_injected(lora_eagle_model): + """LoRA adapters should be present in the base model after conversion.""" + lora_layers = [m for m in lora_eagle_model._base_model.modules() if isinstance(m, LoraLayer)] + assert len(lora_layers) > 0, "No LoRA layers found in base model" + + +def test_trainable_params(lora_eagle_model): + """Only LoRA and eagle_module params should be trainable; base model weights frozen.""" + for name, param in lora_eagle_model.named_parameters(): + is_lora = "lora_" in name + is_eagle = "eagle_module" in name + if is_lora or is_eagle: + assert param.requires_grad, f"Expected {name} to be trainable" + else: + assert not param.requires_grad, f"Expected {name} to be frozen" + + +def test_forward_returns_loss(lora_eagle_model): + """Forward pass should return a scalar loss containing preservation + eagle components.""" + lora_eagle_model.train() + seq_len = 8 + input_ids = torch.randint(0, lora_eagle_model.config.vocab_size, (1, seq_len)) + output = lora_eagle_model(input_ids=input_ids, labels=input_ids) + assert output.loss is not None + assert output.loss.ndim == 0, "Loss should be a scalar" + assert output.loss.item() > 0 + + +def test_eagle_offline_incompatible(): + """eagle_base_lora=True should raise when combined with eagle_offline=True.""" + model = get_tiny_llama(num_hidden_layers=4) + config = deepcopy(EAGLE_LORA_CONFIG) + config["eagle_offline"] = True + with pytest.raises(ValueError, match="eagle_base_lora is incompatible with eagle_offline"): + mtsp.convert(model, mode=[("eagle", config)]) + + +def test_export_lora_artifacts(lora_eagle_model, tmp_path): + """export() should produce lora_adapter_model.safetensors and lora_adapter_config.json.""" + export_dir = tmp_path / "eagle_export" + lora_eagle_model.get_exporter().export(export_dir) + + assert (export_dir / "model.safetensors").exists(), "Eagle model weights missing" + assert (export_dir / "lora_adapter_model.safetensors").exists(), "LoRA weights missing" + assert (export_dir / "lora_adapter_config.json").exists(), "LoRA config missing" diff --git a/tools/launcher/core.py b/tools/launcher/core.py index 40e6c94419..7004f7e662 100644 --- a/tools/launcher/core.py +++ b/tools/launcher/core.py @@ -268,7 +268,7 @@ def build_slurm_executor( container_image=slurm_config.container, container_mounts=container_mounts, array=slurm_config.array, - time="04:00:00", + time=slurm_config.time, mem="0", retries=0, packager=packager,