Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
c2846a3
Add LoRA co-training support for HF EAGLE speculative decoding
yeyu-nvidia Mar 17, 2026
6257ce3
Add peft to speculative_decoding example requirements
yeyu-nvidia Mar 17, 2026
a08826b
Fix AttributeError for eagle_base_lora attributes in EagleModel.modify()
yeyu-nvidia Mar 17, 2026
9dec4f9
Fix TINY_EAGLE_CFG to pass CPU unit tests
yeyu-nvidia Mar 17, 2026
dccbdb4
Fix CPU unit test failures for EAGLE LoRA co-training
yeyu-nvidia Mar 17, 2026
99a11d3
Fix CPU unit test failures for EAGLE LoRA co-training
yeyu-nvidia Mar 18, 2026
045d984
Make peft import lazy in hf_spec_export to avoid ModuleNotFoundError
yeyu-nvidia Mar 18, 2026
7b7ce40
Fix AttributeError for LlamaConfig.dtype in older transformers versions
yeyu-nvidia Mar 18, 2026
6a5f2c1
Fix LlamaConfig.dtype compatibility across transformers versions
yeyu-nvidia Mar 18, 2026
71b0de9
Fix second LlamaConfig.dtype access in _compute_ttt_attention_mask
yeyu-nvidia Mar 18, 2026
b6f4bfb
Fix dtype mismatch in _compute_ttt_attention_mask for float32 models
yeyu-nvidia Mar 18, 2026
21834dd
Fix ruff line-length formatting in _compute_ttt_attention_mask
yeyu-nvidia Mar 18, 2026
3325aba
Address PR review feedback for eagle_base_lora feature
yeyu-nvidia Mar 19, 2026
ea3875d
Address remaining PR review feedback
yeyu-nvidia Mar 19, 2026
8503349
Expose eagle_base_lora co-training args in example scripts
yeyu-nvidia Mar 19, 2026
d4ecf56
Fix case pattern ordering bug in launch_train.sh for eagle_base_lora …
yeyu-nvidia Mar 19, 2026
14057c4
Revert detach and raise preservation loss weight to 1.0 for LoRA co-t…
yeyu-nvidia Mar 19, 2026
6616175
Add LoRA LR multiplier and detach base logits in EAGLE loss
yeyu-nvidia Mar 20, 2026
4779595
style: apply ruff formatting to eagle_utils.py
yeyu-nvidia Mar 20, 2026
0e179f2
Revert preservation loss weight to 0.1
yeyu-nvidia Mar 20, 2026
57bdd75
Add LM loss as direct training signal for LoRA co-training
yeyu-nvidia Mar 20, 2026
e941fca
Free cached GPU memory before AR validation to avoid OOM
yeyu-nvidia Mar 20, 2026
78a2af9
Enable EAGLE loss gradient flow through aux_hiddens to LoRA
yeyu-nvidia Mar 20, 2026
4ca9927
Use scaled gradient leak for LoRA through aux hidden states
yeyu-nvidia Mar 20, 2026
230a763
Switch to alternating EAGLE/LoRA training phases
yeyu-nvidia Mar 23, 2026
f6c527f
Split LoRA training into separate EAGLE-loss and preservation phases
yeyu-nvidia Mar 23, 2026
e268677
Add merge_lora.py script to fuse trained LoRA weights into base model
yeyu-nvidia Mar 23, 2026
c22257b
Skip EAGLE forward in Phase C (preservation-only) to save compute
yeyu-nvidia Mar 23, 2026
2c90281
Log per-phase losses (eagle, lora_eagle, preservation) to wandb
yeyu-nvidia Mar 23, 2026
fa042f1
Fix inhomogeneous train_acc array from Phase C empty lists
yeyu-nvidia Mar 23, 2026
2407c8b
Fix merge_lora.py: use LoraLayer.delete_adapter instead of private API
yeyu-nvidia Mar 24, 2026
3e4d41a
Rewrite merge_lora.py to use PeftModel.merge_and_unload()
yeyu-nvidia Mar 24, 2026
51e7d57
Fix merge_lora.py key mismatch: keep model. prefix and add .default
yeyu-nvidia Mar 24, 2026
3c31fa2
Export LoRA keys in PeftModel-compatible format (.default suffix)
yeyu-nvidia Mar 24, 2026
d621c38
Fail merge_lora.py on missing LoRA adapter keys instead of silently c…
yeyu-nvidia Mar 24, 2026
b43bc70
Infer LoRA target_modules from exported keys in adapter config
yeyu-nvidia Mar 24, 2026
ddb6275
Fix target_modules inference: use [-4] index after .default rename
yeyu-nvidia Mar 24, 2026
e2d07bd
Fix merge_lora.py: add base_model.model. prefix for PeftModel compati…
yeyu-nvidia Mar 24, 2026
e15a042
Fix merge_lora.py for peft 0.18+: no prefix, strip .default from keys
yeyu-nvidia Mar 24, 2026
a4fc244
Fix merge_lora.py: suppress spurious peft warning, verify via norm check
yeyu-nvidia Mar 24, 2026
a74bf92
Fix merge_lora.py: keep .default in keys, verify lora_B norms
yeyu-nvidia Mar 24, 2026
f5374a4
Fix merge_lora.py to use explicit key mapping and pin peft==0.18.1
yeyu-nvidia Mar 24, 2026
0d6ad77
Simplify LoRA co-training to 2 phases with hidden-state gradient flow
yeyu-nvidia Mar 30, 2026
7b61e69
style: fix ruff formatting in aux hidden states hook
yeyu-nvidia Mar 30, 2026
4317669
Move CPU buffers to CUDA before Trainer init to fix DDP multi-node
yeyu-nvidia Mar 31, 2026
f242e8d
Simplify LoRA co-training to single phase: eagle + preservation every…
yeyu-nvidia Mar 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 55 additions & 3 deletions examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,57 @@ def make_eagle_supervised_data_module(
class EagleTrainerWithAccLog(Trainer):
"""Wrapper around Trainer that logs training accuracy."""

def __init__(self, *args, lora_lr_multiplier: float = 1.0, **kwargs):
    """Trainer that can train LoRA parameters at a scaled learning rate.

    Args:
        lora_lr_multiplier: Factor applied to the base learning rate for
            trainable parameters whose name contains ``"lora_"``. The
            default of 1.0 leaves the optimizer untouched.
    """
    super().__init__(*args, **kwargs)
    self.lora_lr_multiplier = lora_lr_multiplier

def create_optimizer(self):
    """Create the optimizer, giving LoRA parameters a scaled learning rate.

    Delegates to the parent implementation first. When
    ``lora_lr_multiplier`` differs from 1.0 and trainable LoRA parameters
    exist, each optimizer param group is split so that parameters whose
    name contains ``"lora_"`` land in a group whose ``lr`` is multiplied
    by the configured factor; all other group settings (weight decay,
    betas, ...) are inherited unchanged.

    Returns:
        The (possibly regrouped) optimizer.
    """
    super().create_optimizer()
    if self.lora_lr_multiplier == 1.0:
        return self.optimizer

    # Identify trainable LoRA parameters by object identity so the group
    # membership test below is O(1) per parameter.
    lora_param_ids = {
        id(param)
        for name, param in self.model.named_parameters()
        if "lora_" in name and param.requires_grad
    }
    if not lora_param_ids:
        return self.optimizer

    regrouped = []
    for group in self.optimizer.param_groups:
        lora_params, other_params = [], []
        for param in group["params"]:
            (lora_params if id(param) in lora_param_ids else other_params).append(param)
        scaled_lr = group["lr"] * self.lora_lr_multiplier
        if not lora_params:
            # No LoRA parameters here — keep the group as-is.
            regrouped.append(group)
        elif not other_params:
            # Group is purely LoRA: just bump its learning rate.
            regrouped.append({**group, "lr": scaled_lr})
        else:
            # Mixed group: split, non-LoRA first to preserve ordering.
            regrouped.append({**group, "params": other_params})
            regrouped.append({**group, "params": lora_params, "lr": scaled_lr})
    self.optimizer.param_groups = regrouped
    return self.optimizer

def compute_loss(self, *args, **kwargs):
    """Compute the loss, recording train accs and per-component losses on the trainer state."""
    state = self.state
    # Lazily create the accumulators on first call; a logging callback is
    # expected to consume and reset them.
    if not hasattr(state, "training_accs"):
        state.training_accs = []
    if not hasattr(state, "component_losses"):
        state.component_losses = {"eagle": [], "preservation": []}
    # NOTE(review): num_items_in_batch is dropped rather than forwarded —
    # presumably because the model forward does not accept it; confirm.
    kwargs.pop("num_items_in_batch", None)
    loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs)
    # Only keep steps that produced at least one non-zero accuracy entry
    # (skips phases that yield empty/zeroed accuracy lists).
    if hasattr(outputs, "train_acc") and any(outputs.train_acc):
        state.training_accs.append(outputs.train_acc)
    # Record each loss component that the model reported this step.
    for component_key, output_attr in (
        ("eagle", "eagle_loss"),
        ("preservation", "preservation_loss"),
    ):
        component_val = getattr(outputs, output_attr, None)
        if component_val is not None:
            state.component_losses[component_key].append(component_val.item())
    return loss


Expand Down Expand Up @@ -230,8 +273,16 @@ def on_log(self, args, state, control, **kwargs):
if self.estimate_ar:
wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)

# reset training_accs
# Log per-component losses
if hasattr(state, "component_losses"):
for key, vals in state.component_losses.items():
if vals:
wandb.log({f"{key}_loss": np.mean(vals)}, step=state.global_step)

# reset training_accs and component_losses
state.training_accs = []
if hasattr(state, "component_losses"):
state.component_losses = {"eagle": [], "preservation": []}
return control

def on_step_end(self, args, state, control, **kwargs):
Expand All @@ -240,6 +291,7 @@ def on_step_end(self, args, state, control, **kwargs):
return control
if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
print_rank_0("Running AR validation...")
torch.cuda.empty_cache()
try:
ars = validate_ar(
model=kwargs["model"],
Expand Down
44 changes: 44 additions & 0 deletions examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,30 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
FSDP="${1#*=}"
;;
# Arms accept both "--flag value" and "--flag=value": when there is no '=',
# shift once so $1 becomes the value; "${1#*=}" then strips any "flag=" prefix.
--eagle_base_lora_rank*)
    if [[ "$1" != *=* ]]; then shift; fi
    EAGLE_BASE_LORA_RANK="${1#*=}"
    ;;
--eagle_base_lora_alpha*)
    if [[ "$1" != *=* ]]; then shift; fi
    EAGLE_BASE_LORA_ALPHA="${1#*=}"
    ;;
--eagle_base_lora_target_modules*)
    if [[ "$1" != *=* ]]; then shift; fi
    EAGLE_BASE_LORA_TARGET_MODULES="${1#*=}"
    ;;
--eagle_base_lora_preservation_loss_weight*)
    if [[ "$1" != *=* ]]; then shift; fi
    EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT="${1#*=}"
    ;;
--eagle_base_lora_lr_multiplier*)
    if [[ "$1" != *=* ]]; then shift; fi
    EAGLE_BASE_LORA_LR_MULTIPLIER="${1#*=}"
    ;;
# NOTE: this catch-all prefix pattern must remain LAST among the
# --eagle_base_lora* arms; placed earlier it would shadow every more
# specific --eagle_base_lora_<suffix> pattern above.
--eagle_base_lora*)
    if [[ "$1" != *=* ]]; then shift; fi
    EAGLE_BASE_LORA="${1#*=}"
    ;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
Expand Down Expand Up @@ -184,6 +208,12 @@ DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}
DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"}
NUM_TTT_STEPS=${NUM_TTT_STEPS:-3}
# Defaults for base-model LoRA co-training (only forwarded to main.py when
# EAGLE_BASE_LORA is "True"; see the LORA_ARGS assembly below).
EAGLE_BASE_LORA=${EAGLE_BASE_LORA:-"False"}
EAGLE_BASE_LORA_RANK=${EAGLE_BASE_LORA_RANK:-64}
EAGLE_BASE_LORA_ALPHA=${EAGLE_BASE_LORA_ALPHA:-16.0}
# Empty means: let main.py fall back to peft's default target modules.
EAGLE_BASE_LORA_TARGET_MODULES=${EAGLE_BASE_LORA_TARGET_MODULES:-""}
EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT=${EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT:-1.0}
EAGLE_BASE_LORA_LR_MULTIPLIER=${EAGLE_BASE_LORA_LR_MULTIPLIER:-1.0}

USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"}
TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"}
Expand Down Expand Up @@ -218,6 +248,19 @@ else
VLM_ARGS=""
fi

# Assemble the LoRA CLI flags forwarded to main.py; empty string when
# co-training is disabled so the final command is unchanged.
if [[ "$EAGLE_BASE_LORA" == "True" ]]; then
    # Backslash-newlines inside double quotes are line continuations, so
    # LORA_ARGS becomes a single space-separated flag string.
    LORA_ARGS="--eagle_base_lora True \
        --eagle_base_lora_rank $EAGLE_BASE_LORA_RANK \
        --eagle_base_lora_alpha $EAGLE_BASE_LORA_ALPHA \
        --eagle_base_lora_preservation_loss_weight $EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT \
        --eagle_base_lora_lr_multiplier $EAGLE_BASE_LORA_LR_MULTIPLIER"
    # Only pass target modules when explicitly set; otherwise main.py uses
    # peft's defaults.
    if [[ "$EAGLE_BASE_LORA_TARGET_MODULES" != "" ]]; then
        LORA_ARGS="$LORA_ARGS --eagle_base_lora_target_modules $EAGLE_BASE_LORA_TARGET_MODULES"
    fi
else
    LORA_ARGS=""
fi

if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
Expand Down Expand Up @@ -283,6 +326,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
--cp_size $CP_SIZE \
--dp_shard_size $DP_SHARD_SIZE \
--num_ttt_steps $NUM_TTT_STEPS \
$LORA_ARGS \
"

start_time=$(date +%s)
Expand Down
68 changes: 68 additions & 0 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,49 @@ class EagleArguments:
default=3,
metadata={"help": "Number of train-time-test steps to use during training."},
)
# --- Base-model LoRA co-training options (consumed by main.py's train()) ---
# Master switch; requires `peft` and is incompatible with offline training.
eagle_base_lora: bool = field(
    default=False,
    metadata={
        "help": (
            "Whether to add LoRA adapters to the base model for co-training with the EAGLE "
            "draft module. Requires the `peft` library. Incompatible with offline training."
        )
    },
)
# LoRA rank r of the injected adapter matrices.
eagle_base_lora_rank: int = field(
    default=64,
    metadata={"help": "LoRA rank for the base model adapters."},
)
# LoRA alpha scaling factor.
eagle_base_lora_alpha: float = field(
    default=16.0,
    metadata={"help": "LoRA alpha (scaling) for the base model adapters."},
)
# None defers to peft's default target modules; otherwise a comma-separated
# pattern list that train() splits into a Python list.
eagle_base_lora_target_modules: str | None = field(
    default=None,
    metadata={
        "help": (
            "Comma-separated list of module name patterns to apply LoRA to in the base model "
            "(e.g. 'q_proj,v_proj'). Defaults to peft's default target modules."
        )
    },
)
# Weight of the KL preservation term that keeps the LoRA-adapted base model
# close to the original base model outputs.
eagle_base_lora_preservation_loss_weight: float = field(
    default=1.0,
    metadata={
        "help": (
            "Weight for the preservation loss that minimizes KL divergence between the "
            "LoRA-adapted base model output and the original base model output."
        )
    },
)
# Multiplier applied to the base LR for LoRA parameters (see
# EagleTrainerWithAccLog in eagle_utils.py).
eagle_base_lora_lr_multiplier: float = field(
    default=1.0,
    metadata={
        "help": (
            "Learning rate multiplier for LoRA parameters relative to the base learning rate."
        )
    },
)


def train():
Expand Down Expand Up @@ -217,13 +260,23 @@ def train():
json.load(open(eagle_args.eagle_config)) if eagle_args.eagle_config else {}
)

lora_target_modules = (
eagle_args.eagle_base_lora_target_modules.split(",")
if eagle_args.eagle_base_lora_target_modules
else None
)
config = {
"eagle_decoder_type": eagle_args.eagle_decoder_type,
"eagle_offline": use_offline_training,
"eagle_mix_hidden_states": eagle_args.mix_hidden_states,
"eagle_use_torch_compile": not eagle_args.disable_torch_compile,
"eagle_ttt_steps": eagle_args.num_ttt_steps,
"eagle_architecture_config": custom_config,
"eagle_base_lora": eagle_args.eagle_base_lora,
"eagle_base_lora_rank": eagle_args.eagle_base_lora_rank,
"eagle_base_lora_alpha": eagle_args.eagle_base_lora_alpha,
"eagle_base_lora_target_modules": lora_target_modules,
"eagle_base_lora_preservation_loss_weight": eagle_args.eagle_base_lora_preservation_loss_weight,
}

mtsp.convert(model, [("eagle", config)])
Expand All @@ -239,6 +292,20 @@ def train():
else:
raise Exception(f"{training_args.mode} is not supported!")

# Move any remaining CPU buffers to CUDA so DDP (NCCL-only) can broadcast
# them. We iterate named_buffers and reassign via the owning module to
# keep the module tree consistent. Parameters are left on CPU — the HF
# Trainer will move them during init.
if torch.cuda.is_available():
_target_dev = torch.device("cuda", 0)
for name, buf in list(model.named_buffers()):
if buf.device.type == "cpu":
parts = name.split(".")
mod = model
for p in parts[:-1]:
mod = getattr(mod, p)
setattr(mod, parts[-1], buf.to(_target_dev))

print_rank_0("Loading dataset...")
if training_args.mode == "eagle3":
data_module = make_eagle_supervised_data_module(
Expand All @@ -250,6 +317,7 @@ def train():
processing_class=tokenizer,
args=training_args,
callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)],
lora_lr_multiplier=eagle_args.eagle_base_lora_lr_multiplier,
**data_module,
)

Expand Down
1 change: 1 addition & 0 deletions examples/speculative_decoding/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
accelerate==1.12.0
peft==0.18.1
transformers==5.0.0rc1
Loading
Loading