add: DFlash block diffusion speculative decoding #1128
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. Use the following commands to manage reviews:
📝 Walkthrough

Adds DFlash speculative decoding: a new DFlash config and defaults, conversion/restore utilities and a registry, a DFlash base model class, and a HuggingFace plugin implementing the training/generation logic; integrates a new "dflash" mode with example/training adjustments.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant BaseModel as Base LLM
    participant HiddenCollector as Hidden State<br/>Collector
    participant Fusion as Feature Fusion<br/>(FC + RMSNorm)
    participant DraftModule as DFlash Draft<br/>Module (decoder stack)
    participant LMHead as Logit Head
    participant Loss as Loss & Accuracy
    User->>BaseModel: input_ids + attention_mask
    BaseModel->>HiddenCollector: forward -> hidden states
    HiddenCollector->>Fusion: collect target-layer states
    Fusion->>DraftModule: fused targets + noise embeddings
    DraftModule->>LMHead: draft hidden -> logits
    LMHead->>Loss: compute CE / accuracy
    Loss-->>User: loss + accuracy
    rect rgba(100,150,200,0.5)
        Note over BaseModel,Loss: DFlash training forward pass
    end
```
```mermaid
sequenceDiagram
    actor User
    participant BaseModel as Base LLM
    participant DFlashMod as DFlash Module
    participant BlockBuilder as Block Builder
    participant DraftDecoder as Draft Decoder
    participant TokenSel as Token Selector
    User->>BaseModel: input_ids (context)
    loop for each generation step
        BaseModel->>DFlashMod: base next-token anchor + hidden states
        DFlashMod->>DFlashMod: fuse target layers
        loop for each block position
            BlockBuilder->>DraftDecoder: build noise (anchor + mask), run draft step
            DraftDecoder->>TokenSel: draft logits -> argmax
        end
        DraftDecoder-->>User: base token + draft block
    end
    rect rgba(150,200,100,0.5)
        Note over BaseModel,TokenSel: DFlash pseudo-speculative generation
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 2
🧹 Nitpick comments (4)
modelopt/torch/speculative/dflash/dflash_model.py (1)
27-34: Add type hint for `config` parameter.

The `config` parameter lacks a type annotation. Per project standards, type hints should be provided for static type checking with mypy.

♻️ Proposed fix
```diff
+from ..config import DFlashConfig
+
+
 class DFlashModel(DynamicModule):
     """Base DFlash Model."""

     def _setup(self):
         self._register_temp_attribute("dflash_module", None)

-    def modify(self, config):
+    def modify(self, config: DFlashConfig):
         """Base DFlash Model modify function. Child class should implement the details."""
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/dflash/dflash_model.py` around lines 27 - 34, The modify method's config parameter is missing a type annotation; update def modify(self, config) to include the appropriate config type (e.g., def modify(self, config: DFlashConfig)) and import that type at the top of the module from wherever the project's config dataclass/typing lives; if a concrete config type isn't available yet, annotate with typing.Any as a temporary fallback and add the proper import for Any. Ensure you update imports (from typing import Any or from <module> import DFlashConfig) and keep the existing attribute assignments in modify unchanged.

modelopt/torch/speculative/plugins/hf_dflash.py (3)
320-327: Silent exception handling may hide initialization errors.

The broad `except Exception: continue` pattern silently ignores all errors when locating base model parts, potentially masking genuine issues like attribute errors or type mismatches.

♻️ Proposed improvement to catch only expected exceptions

```diff
 for path in paths:
     try:
         submodule = self.get_submodule(path)
         assert isinstance(submodule, torch.nn.Module)
         setattr(self, name, path)
         break
-    except Exception:
+    except (AttributeError, AssertionError):
         continue
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 320 - 327, The loop that tries to locate submodules uses a broad "except Exception: continue", which can hide real errors; change it to only catch expected exceptions (e.g., except (AttributeError, AssertionError, TypeError): continue) so intentional lookup failures are ignored but other unexpected exceptions surface (or are logged/re-raised); keep the calls to self.get_submodule and setattr(self, name, path) intact but ensure you handle and/or log unexpected exceptions rather than swallowing them silently.
471-474: Complex anchor sampling logic is hard to follow.

The nested `max()`, `min()`, and `range()` calls make this expression difficult to reason about. Consider breaking it into intermediate variables for clarity and easier debugging.

♻️ Suggested refactor for readability

```diff
-        num_blocks = max(1, max_anchor // block_size)
-        # Sample anchor positions uniformly
-        anchors = sorted(
-            random.sample(range(1, max(2, max_anchor)), min(num_blocks, max(1, max_anchor - 1)))
-        )
+        num_blocks = max(1, max_anchor // block_size)
+        # Sample anchor positions uniformly
+        sample_range_end = max(2, max_anchor)
+        sample_count = min(num_blocks, max(1, max_anchor - 1))
+        anchors = sorted(random.sample(range(1, sample_range_end), sample_count))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 471 - 474, The anchor sampling expression assigned to anchors is hard to read; break it into named intermediate variables (e.g., compute max_anchor_bound = max(2, max_anchor), sample_upper = max(1, max_anchor - 1), num_to_sample = min(num_blocks, sample_upper), and the range_to_sample = range(1, max_anchor_bound)) and then call random.sample(range_to_sample, num_to_sample) and sort the result; update the code in hf_dflash.py where anchors is defined (the anchors variable in the anchor sampling block) to use these intermediate names for clarity and easier debugging.
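For illustration, the refactored sampling can be run standalone. The function name `sample_anchor_positions` and its arguments are hypothetical stand-ins, not the plugin's actual API; only the control flow mirrors the suggestion above:

```python
import random


def sample_anchor_positions(max_anchor: int, block_size: int) -> list[int]:
    # Named intermediates, as the refactor suggests, instead of one
    # nested max()/min()/range() expression.
    num_blocks = max(1, max_anchor // block_size)
    sample_range_end = max(2, max_anchor)  # exclusive upper bound for positions
    sample_count = min(num_blocks, max(1, max_anchor - 1))
    return sorted(random.sample(range(1, sample_range_end), sample_count))
```

With the intermediates named, the clamping behavior (never sampling more positions than exist in `[1, max_anchor)`) is visible at a glance.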
346-347: Accessing private `_attn_implementation` attribute is fragile.

`_attn_implementation` is a private attribute of `PretrainedConfig` that may change between transformers versions without notice.

♻️ Suggested defensive check

```diff
-        if self.dflash_config._attn_implementation is None:
+        if getattr(self.dflash_config, "_attn_implementation", None) is None:
             self.dflash_config._attn_implementation = "sdpa"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 346 - 347, Replace the fragile direct access to the private attribute self.dflash_config._attn_implementation with a defensive check using getattr/hasattr and set via setattr (e.g., current = getattr(self.dflash_config, "attn_implementation", None) or fallback = getattr(self.dflash_config, "_attn_implementation", None); if current is None: setattr(self.dflash_config, "attn_implementation", "sdpa") ), so the code uses public names when present and only falls back to the underscore name if necessary, ensuring compatibility across transformers versions.
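A minimal sketch of the defensive check, using a hypothetical `FakeConfig` in place of a real transformers config object:

```python
class FakeConfig:
    # Hypothetical stand-in for a transformers PretrainedConfig.
    def __init__(self, attn=None):
        self._attn_implementation = attn


def ensure_attn_implementation(config, default="sdpa"):
    # Read the private field defensively: a missing attribute and an
    # explicit None both fall back to the default.
    if getattr(config, "_attn_implementation", None) is None:
        setattr(config, "_attn_implementation", default)
    return config._attn_implementation
```

The `getattr(..., None)` form keeps the code working even if a future transformers release drops the underscore attribute entirely.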
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 368-375: The modify() method currently registers forward hooks
each time without removing prior hooks, causing duplicated collections; fix by
tracking hook handles and removing old hooks before registering new ones: add a
persistent attribute (e.g., self._registered_forward_hooks = []) initialized in
the class (constructor), at the start of modify() iterate over
self._registered_forward_hooks calling handle.remove() and then clear the list,
then when registering forward hooks on layers (where you call
layer.register_forward_hook(self._collect_hidden_hook)) capture each returned
handle and append it to self._registered_forward_hooks; also ensure you reset
self._target_hidden_states (and optionally self._cached_masks) when re-modifying
to avoid stale state.
- Around line 55-62: The build_target_layer_ids function collapses to the same
index when num_target_layers <= 4 because start=1 and end=num_target_layers-3
produce span <= 0; change the logic to handle shallow models: if
num_target_layers <= 4 (or if num_sample_layers >= num_target_layers) return a
set of valid unique layer indices (e.g., evenly spaced or simply range(0,
num_target_layers) trimmed to num_sample_layers) and otherwise compute evenly
spaced indices across [0, num_target_layers-1]; update build_target_layer_ids to
clamp/adjust start/end and to limit num_sample_layers to at most
num_target_layers so indices remain unique and within bounds.
---
Nitpick comments:
In `@modelopt/torch/speculative/dflash/dflash_model.py`:
- Around line 27-34: The modify method's config parameter is missing a type
annotation; update def modify(self, config) to include the appropriate config
type (e.g., def modify(self, config: DFlashConfig)) and import that type at the
top of the module from wherever the project's config dataclass/typing lives; if
a concrete config type isn't available yet, annotate with typing.Any as a
temporary fallback and add the proper import for Any. Ensure you update imports
(from typing import Any or from <module> import DFlashConfig) and keep the
existing attribute assignments in modify unchanged.
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 320-327: The loop that tries to locate submodules uses a broad
"except Exception: continue", which can hide real errors; change it to only
catch expected exceptions (e.g., except (AttributeError, AssertionError,
TypeError): continue) so intentional lookup failures are ignored but other
unexpected exceptions surface (or are logged/re-raised); keep the calls to
self.get_submodule and setattr(self, name, path) intact but ensure you handle
and/or log unexpected exceptions rather than swallowing them silently.
- Around line 471-474: The anchor sampling expression assigned to anchors is
hard to read; break it into named intermediate variables (e.g., compute
max_anchor_bound = max(2, max_anchor), sample_upper = max(1, max_anchor - 1),
num_to_sample = min(num_blocks, sample_upper), and the range_to_sample =
range(1, max_anchor_bound)) and then call random.sample(range_to_sample,
num_to_sample) and sort the result; update the code in hf_dflash.py where
anchors is defined (the anchors variable in the anchor sampling block) to use
these intermediate names for clarity and easier debugging.
- Around line 346-347: Replace the fragile direct access to the private
attribute self.dflash_config._attn_implementation with a defensive check using
getattr/hasattr and set via setattr (e.g., current = getattr(self.dflash_config,
"attn_implementation", None) or fallback = getattr(self.dflash_config,
"_attn_implementation", None); if current is None: setattr(self.dflash_config,
"attn_implementation", "sdpa") ), so the code uses public names when present and
only falls back to the underscore name if necessary, ensuring compatibility
across transformers versions.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 99f4f5f4-7d81-4e34-b82f-d958548f8d6d
📒 Files selected for processing (9)
- examples/speculative_decoding/main.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/dflash/__init__.py
- modelopt/torch/speculative/dflash/conversion.py
- modelopt/torch/speculative/dflash/default_config.py
- modelopt/torch/speculative/dflash/dflash_model.py
- modelopt/torch/speculative/mode.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
```python
def build_target_layer_ids(num_target_layers, num_sample_layers):
    """Select layers uniformly from the target model for feature extraction."""
    if num_sample_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)]
```
Edge case: small num_target_layers values produce degenerate sampling.
When `num_target_layers <= 4`, `end <= start` (e.g., `num_target_layers=4` → `start=1`, `end=1`, `span=0`). All returned layer indices collapse to the same value, defeating uniform sampling. Consider adding a guard or adjusting the formula for shallow target models.
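The collapse is easy to reproduce with a verbatim copy of the function under review:

```python
def build_target_layer_ids(num_target_layers, num_sample_layers):
    """Verbatim copy of the function under review (unguarded form)."""
    if num_sample_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    # With num_target_layers <= 4, span <= 0 and all indices collapse
    # (or even go negative for very shallow models).
    return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)]
```

For a 4-layer target the function returns the same index for every sample; for a 2-layer target it can even return a negative index, while deeper models behave as intended.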
🛡️ Proposed fix to handle edge case

```diff
 def build_target_layer_ids(num_target_layers, num_sample_layers):
     """Select layers uniformly from the target model for feature extraction."""
     if num_sample_layers == 1:
         return [num_target_layers // 2]
+    if num_target_layers <= 4:
+        # For very shallow models, sample from all available layers
+        start, end = 0, num_target_layers - 1
+    else:
+        start = 1
+        end = num_target_layers - 3
-    start = 1
-    end = num_target_layers - 3
     span = end - start
+    if span <= 0:
+        return [start] * num_sample_layers
     return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 55 - 62, The
build_target_layer_ids function collapses to the same index when
num_target_layers <= 4 because start=1 and end=num_target_layers-3 produce span
<= 0; change the logic to handle shallow models: if num_target_layers <= 4 (or
if num_sample_layers >= num_target_layers) return a set of valid unique layer
indices (e.g., evenly spaced or simply range(0, num_target_layers) trimmed to
num_sample_layers) and otherwise compute evenly spaced indices across [0,
num_target_layers-1]; update build_target_layer_ids to clamp/adjust start/end
and to limit num_sample_layers to at most num_target_layers so indices remain
unique and within bounds.
Codecov Report: ❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1128      +/-   ##
==========================================
+ Coverage   70.18%   70.25%   +0.06%
==========================================
  Files         230      234       +4
  Lines       26080    26146      +66
==========================================
+ Hits        18304    18368      +64
- Misses       7776     7778       +2
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/hf_dflash.py (4)
307-309: Device placement assumes `_base_model.layers` exists and is non-empty.

If the base model has a different structure (e.g., no `layers` attribute or empty layers), this will raise an `AttributeError` or `IndexError`.

🛡️ More robust device detection

```diff
         self.dflash_module = DFlashModule(self.dflash_config)
-        self.dflash_module.to(self._base_model.dtype).to(
-            next(self._base_model.layers[-1].parameters()).device
-        )
+        # Get device from any base model parameter
+        base_device = next(self._base_model.parameters()).device
+        base_dtype = next(self._base_model.parameters()).dtype
+        self.dflash_module.to(base_dtype).to(base_device)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 307 - 309, The placement code for self.dflash_module assumes self._base_model.layers exists and has elements; instead, detect device robustly by checking for a non-empty self._base_model.layers and using next(self._base_model.layers[-1].parameters()).device only when present, otherwise fall back to using next(self._base_model.parameters()).device (or a CPU default if parameters are absent); apply the chosen device and the target dtype (self._base_model.dtype) to self.dflash_module so device/dtype setting works for models without a layers attribute or with empty layers.
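A sketch of the suggested fallback chain. The duck-typed stand-ins below (with plain string device names) are hypothetical, chosen so the control flow is testable without real torch modules:

```python
def resolve_device(model):
    """Prefer the last layer's device; fall back to any parameter.

    `model` is assumed to expose `.layers` (possibly absent or empty)
    and `.parameters()`, as torch modules do.
    """
    layers = getattr(model, "layers", None)
    if layers:
        try:
            return next(layers[-1].parameters()).device
        except StopIteration:
            pass
    try:
        return next(model.parameters()).device
    except StopIteration:
        return "cpu"  # conservative default when the model has no parameters
```

The same pattern works for dtype; the point is that every attribute access has a fallback instead of assuming one particular module layout.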
82-82: Unused attribute `is_causal`.

The `is_causal` attribute is set to `False` but never referenced. The value is hardcoded directly in the `scaled_dot_product_attention` call at line 129.

♻️ Suggested cleanup

```diff
         self.num_kv_heads = config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
-        self.is_causal = False
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 82, The instance attribute self.is_causal is defined but never used; either remove its assignment or wire it into the attention call — replace the hardcoded False in the scaled_dot_product_attention invocation with self.is_causal (or delete the self.is_causal assignment if you prefer no flag). Update the attribute in the class initializer where self.is_causal is set and the scaled_dot_product_attention call at the line that currently passes False so they are consistent.
412-414: Zero loss tensor with `requires_grad=True` won't propagate gradients.

When `active_logits.numel() == 0`, the returned loss tensor is a constant `0.0` with `requires_grad=True`. While this avoids errors during the backward pass, it produces a gradient of zero for all parameters. Consider logging a warning when this edge case occurs.

💡 Add warning for visibility

```diff
+import logging
+
+logger = logging.getLogger(__name__)
+
 if active_logits.numel() > 0:
     loss = F.cross_entropy(active_logits, active_labels)
     with torch.no_grad():
         preds = active_logits.argmax(dim=-1)
         accuracy = (preds == active_labels).float().mean().item()
 else:
+    logger.warning("No active positions for loss computation")
     loss = torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)
     accuracy = 0.0
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 412 - 414, When active_logits.numel() == 0 the code returns loss = torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True) which yields zero gradients; update the branch handling this case to also emit a clear warning so the condition is visible in logs (e.g., log that active_logits is empty and loss is a zero tensor) and keep returning the tensor and accuracy as currently done. Locate the branch using active_logits.numel() == 0 (the else that sets loss and accuracy), add a warning via the module/logger used in this file (e.g., logger.warning or process_logger) mentioning the function/context (hf_dflash loss computation) and the tensor/device/dtype, and ensure no change to the returned tensor shape or requires_grad behavior.
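The guard-plus-warning pattern can be shown without torch. This pure-Python `masked_ce_loss` is only an illustration (the real code returns a zero tensor with `requires_grad=True`, not a float, and uses `F.cross_entropy`):

```python
import logging
import math

logger = logging.getLogger(__name__)


def masked_ce_loss(logits, labels):
    """Cross-entropy over active positions with an explicit empty-batch guard.

    `logits` is a list of per-position logit rows; `labels` the matching
    class indices. Returns (mean_loss, accuracy).
    """
    if not logits:
        # Make the degenerate case visible instead of silently returning 0.
        logger.warning("No active positions for loss computation")
        return 0.0, 0.0
    losses, correct = [], 0
    for row, label in zip(logits, labels):
        z = max(row)  # stabilize logsumexp
        logsumexp = z + math.log(sum(math.exp(v - z) for v in row))
        losses.append(logsumexp - row[label])
        correct += int(max(range(len(row)), key=row.__getitem__) == label)
    return sum(losses) / len(losses), correct / len(losses)
```

The warning costs nothing in the normal path but makes a silently-zero training step easy to spot in logs.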
285-286: Accessing private HuggingFace attribute `_attn_implementation`.

`_attn_implementation` is a private/internal attribute in HuggingFace configs that may change without notice. Consider using the public API or documenting this dependency.

♻️ More defensive approach

```diff
-        if self.dflash_config._attn_implementation is None:
-            self.dflash_config._attn_implementation = "eager"
+        if getattr(self.dflash_config, "_attn_implementation", None) is None:
+            self.dflash_config._attn_implementation = "eager"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 285 - 286, The code currently writes directly to the private HuggingFace attribute self.dflash_config._attn_implementation; instead, make this defensive by checking for the attribute's existence with getattr/hasattr and only set it when present, otherwise fall back to a documented public config or internal default and log/warn about relying on a private attribute; update the assignment around self.dflash_config._attn_implementation to use a safe check (getattr(self.dflash_config, "_attn_implementation", None)) and a clear fallback path so the plugin won't break if the private field is removed in future.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 47-56: The function build_target_layer_ids should avoid the
redundant int(round(...)) and handle shallow models where span <= 0; change the
rounding to just round(...) (or remove the int() wrapper) and add a guard when
num_target_layers <= 4 (or when end <= start) to produce sensible unique
indices: compute end = max(1, num_target_layers - 3), compute span = max(1, end
- start), and if num_draft_layers > (end - start + 1) clamp num_draft_layers to
that maximum so the list comprehension in build_target_layer_ids returns evenly
spaced, non-duplicated layer indices (use the existing variables
num_target_layers, num_draft_layers, start, end, span in the fix).
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 307-309: The placement code for self.dflash_module assumes
self._base_model.layers exists and has elements; instead, detect device robustly
by checking for a non-empty self._base_model.layers and using
next(self._base_model.layers[-1].parameters()).device only when present,
otherwise fall back to using next(self._base_model.parameters()).device (or a
CPU default if parameters are absent); apply the chosen device and the target
dtype (self._base_model.dtype) to self.dflash_module so device/dtype setting
works for models without a layers attribute or with empty layers.
- Line 82: The instance attribute self.is_causal is defined but never used;
either remove its assignment or wire it into the attention call — replace the
hardcoded False in the scaled_dot_product_attention invocation with
self.is_causal (or delete the self.is_causal assignment if you prefer no flag).
Update the attribute in the class initializer where self.is_causal is set and
the scaled_dot_product_attention call at the line that currently passes False so
they are consistent.
- Around line 412-414: When active_logits.numel() == 0 the code returns loss =
torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True) which yields
zero gradients; update the branch handling this case to also emit a clear
warning so the condition is visible in logs (e.g., log that active_logits is
empty and loss is a zero tensor) and keep returning the tensor and accuracy as
currently done. Locate the branch using active_logits.numel() == 0 (the else
that sets loss and accuracy), add a warning via the module/logger used in this
file (e.g., logger.warning or process_logger) mentioning the function/context
(hf_dflash loss computation) and the tensor/device/dtype, and ensure no change
to the returned tensor shape or requires_grad behavior.
- Around line 285-286: The code currently writes directly to the private
HuggingFace attribute self.dflash_config._attn_implementation; instead, make
this defensive by checking for the attribute's existence with getattr/hasattr
and only set it when present, otherwise fall back to a documented public config
or internal default and log/warn about relying on a private attribute; update
the assignment around self.dflash_config._attn_implementation to use a safe
check (getattr(self.dflash_config, "_attn_implementation", None)) and a clear
fallback path so the plugin won't break if the private field is removed in
future.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 3311c2cf-da29-4e8f-bbc0-c97039583962
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/hf_dflash.py
```python
def build_target_layer_ids(num_target_layers, num_draft_layers):
    """Select layers uniformly from the target model for feature extraction."""
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    return [
        int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers)
    ]
```
Fix linting error and handle edge case for shallow target models.

Two issues:

- Pipeline failure (Line 55): `round()` already returns an `int` in Python 3, making the `int()` wrapper redundant.
- Edge case: When `num_target_layers <= 4`, `span` becomes ≤ 0 (e.g., `num_target_layers=4` → `start=1`, `end=1`, `span=0`), causing all returned indices to collapse to the same value.
🐛 Proposed fix

```diff
 def build_target_layer_ids(num_target_layers, num_draft_layers):
     """Select layers uniformly from the target model for feature extraction."""
     if num_draft_layers == 1:
         return [num_target_layers // 2]
-    start = 1
-    end = num_target_layers - 3
+    # For shallow models, use full range; otherwise skip first and last few layers
+    if num_target_layers <= 4:
+        start, end = 0, num_target_layers - 1
+    else:
+        start, end = 1, num_target_layers - 3
     span = end - start
+    if span <= 0 or num_draft_layers > num_target_layers:
+        # Fallback: return middle layer repeated
+        return [num_target_layers // 2] * num_draft_layers
     return [
-        int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers)
+        round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)
     ]
```
🪛 GitHub Actions: Code Quality
[error] 55-55: ruff check failed (RUF046). Value being cast to int is already an integer. Line suggests int(round(...)); help: remove unnecessary int call.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 47 - 56, The
function build_target_layer_ids should avoid the redundant int(round(...)) and
handle shallow models where span <= 0; change the rounding to just round(...)
(or remove the int() wrapper) and add a guard when num_target_layers <= 4 (or
when end <= start) to produce sensible unique indices: compute end = max(1,
num_target_layers - 3), compute span = max(1, end - start), and if
num_draft_layers > (end - start + 1) clamp num_draft_layers to that maximum so
the list comprehension in build_target_layer_ids returns evenly spaced,
non-duplicated layer indices (use the existing variables num_target_layers,
num_draft_layers, start, end, span in the fix).
Actionable comments posted: 1
🧹 Nitpick comments (3)
modelopt/torch/speculative/plugins/hf_dflash.py (3)
364-364: Replace `print()` with logging.

Production code should use the logging module for better control over output verbosity.

♻️ Proposed fix

Add import at top of file:

```python
import logging

logger = logging.getLogger(__name__)
```

Then replace the print:

```diff
-        print(f"DFlash: using {original_cls.__name__}.forward as base forward")
+        logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 364, Replace the stray print call with structured logging: add an import for the logging module and initialize a module-level logger (e.g., logger = logging.getLogger(__name__)), then replace the print(f"DFlash: using {original_cls.__name__}.forward as base forward") in hf_dflash.py with an appropriate logger call (logger.info or logger.debug) referencing original_cls.__name__ to retain the same message content; ensure the logger is defined at top of the file so the statement in the speculative plugin uses it.
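A minimal runnable sketch of the suggested logging setup; the logger name and helper function here are hypothetical, not the plugin's actual code:

```python
import logging

logger = logging.getLogger("hf_dflash_sketch")  # hypothetical logger name


def announce_base_forward(original_cls):
    # Log (rather than print) which class supplies the base forward,
    # so the message respects the application's logging configuration.
    msg = f"DFlash: using {original_cls.__name__}.forward as base forward"
    logger.info(msg)
    return msg
```

Callers can then silence or route the message with the standard `logging` configuration instead of filtering stdout.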
260-269: Overly broad exception handling.

Catching `Exception` may mask unexpected errors (e.g., `AttributeError`, `TypeError`). Consider narrowing to the specific exceptions expected from `get_submodule`.

♻️ Proposed fix

```diff
 for path in paths:
     try:
         submodule = self.get_submodule(path)
         assert isinstance(submodule, torch.nn.Module)
         setattr(self, name, path)
         break
-    except Exception:
+    except (AttributeError, AssertionError):
         continue
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 260 - 269, Narrow the broad except in the for-loop that calls get_submodule: replace "except Exception" with a targeted exception tuple such as "except (AttributeError, KeyError):" (these are the likely errors when a submodule path is not found) so only lookup-related failures are silenced while letting other errors (TypeError, AssertionError, etc.) surface; keep the rest of the logic in the loop (get_submodule, isinstance check, setattr(self, name, path) and the final ValueError) unchanged.
82-82: Unused attribute `is_causal`.

`self.is_causal` is assigned but never referenced. Line 129 hardcodes `is_causal=False` directly.

♻️ Proposed fix

```diff
-        self.is_causal = False
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 82, The attribute self.is_causal is assigned but never used; replace the hardcoded is_causal=False occurrence in this module with the instance attribute so the class-level flag is honored: initialize self.is_causal in the class constructor (as currently present) and change the call/site that currently passes is_causal=False to pass is_causal=self.is_causal instead; alternatively, if causality is never intended to be configurable, remove the unused self.is_causal assignment and keep the hardcoded False—prefer the first option to preserve configurability.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 510-521: The debug print block guarded by self._psg_debug should
be removed or replaced with proper logging and a safe batch-aware access; locate
the block around _psg_debug that inspects base_outputs.hidden_states,
base_token, mask_token_id, dflash_block_size and calls
self._base_model_embeddings and either delete it or change prints to
logger.debug(...) and replace the unsafe base_token.item() with a batch-safe
access such as base_token[0,0].item() so it won't fail when bsz > 1; keep the
one-time flag behavior if you want this to run only once.
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Line 364: Replace the stray print call with structured logging: add an import
for the logging module and initialize a module-level logger (e.g., logger =
logging.getLogger(__name__)), then replace the print(f"DFlash: using
{original_cls.__name__}.forward as base forward") in hf_dflash.py with an
appropriate logger call (logger.info or logger.debug) referencing
original_cls.__name__ to retain the same message content; ensure the logger is
defined at top of the file so the statement in the speculative plugin uses it.
- Around line 260-269: Narrow the broad except in the for-loop that calls
get_submodule: replace "except Exception" with a targeted exception tuple such
as "except (AttributeError, KeyError):" (these are the likely errors when a
submodule path is not found) so only lookup-related failures are silenced while
letting other errors (TypeError, AssertionError, etc.) surface; keep the rest of
the logic in the loop (get_submodule, isinstance check, setattr(self, name,
path) and the final ValueError) unchanged.
- Line 82: The attribute self.is_causal is assigned but never used; replace the
hardcoded is_causal=False occurrence in this module with the instance attribute
so the class-level flag is honored: initialize self.is_causal in the class
constructor (as currently present) and change the call/site that currently
passes is_causal=False to pass is_causal=self.is_causal instead; alternatively,
if causality is never intended to be configurable, remove the unused
self.is_causal assignment and keep the hardcoded False—prefer the first option
to preserve configurability.
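A minimal sketch of the preferred option (module and parameter names are illustrative, not the actual DFlashAttention implementation): the flag is stored once in the constructor and then actually passed through to the attention call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DraftAttention(nn.Module):
    """Sketch: store the causality flag and honor it at the call site,
    instead of hardcoding is_causal=False."""

    def __init__(self, hidden_size: int, num_heads: int, is_causal: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        self.is_causal = is_causal  # honored below, not just stored

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        shape = (bsz, seq, self.num_heads, self.head_dim)
        q, k, v = (t.view(shape).transpose(1, 2) for t in (q, k, v))
        # Pass the instance flag instead of a hardcoded False.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        return out.transpose(1, 2).reshape(bsz, seq, -1)


attn = DraftAttention(hidden_size=16, num_heads=2, is_causal=True)
out = attn(torch.randn(1, 4, 16))
print(out.shape)  # torch.Size([1, 4, 16])
```

This keeps the default behavior identical (non-causal) while making causality configurable per instance, as the review comment suggests.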
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 90e791a0-d745-4cf9-8a9d-1600e8d1ad34
📒 Files selected for processing (3)
- examples/speculative_decoding/main.py
- modelopt/torch/speculative/dflash/default_config.py
- modelopt/torch/speculative/plugins/hf_dflash.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/speculative/dflash/default_config.py
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/speculative_decoding/main.py
```python
if not hasattr(self, '_psg_debug'):
    self._psg_debug = True
    sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
    th = torch.cat(sel, dim=-1)
    print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}")
    print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
    print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}")
    bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device)
    bi[0, 0] = base_token[0, 0]
    ne = self._base_model_embeddings(bi)
    print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}")
    print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]")
```
Remove debug print statements from production code.
These debug prints are guarded by _psg_debug but will still execute once per model instance. Additionally, line 515 uses base_token.item() which will fail when bsz > 1.
🐛 Proposed fix
Either remove the debug block entirely, or convert to proper logging:
```diff
+import logging
+logger = logging.getLogger(__name__)
+
 ...
-if not hasattr(self, '_psg_debug'):
-    self._psg_debug = True
-    sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
-    th = torch.cat(sel, dim=-1)
-    print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}")
-    print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
-    print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}")
-    bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device)
-    bi[0, 0] = base_token[0, 0]
-    ne = self._base_model_embeddings(bi)
-    print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}")
-    print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]")
```

If debug logging is needed, use logger.debug() and fix the batch-size issue:

```python
logger.debug(f"[psg] base_token: {base_token[0, 0].item()}, mask_token_id: {self.mask_token_id}")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if not hasattr(self, '_psg_debug'):
    self._psg_debug = True
    sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
    th = torch.cat(sel, dim=-1)
    print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}")
    print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
    print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}")
    bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device)
    bi[0, 0] = base_token[0, 0]
    ne = self._base_model_embeddings(bi)
    print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}")
    print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 510 - 521, The
debug print block guarded by self._psg_debug should be removed or replaced with
proper logging and a safe batch-aware access; locate the block around _psg_debug
that inspects base_outputs.hidden_states, base_token, mask_token_id,
dflash_block_size and calls self._base_model_embeddings and either delete it or
change prints to logger.debug(...) and replace the unsafe base_token.item() with
a batch-safe access such as base_token[0,0].item() so it won't fail when bsz >
1; keep the one-time flag behavior if you want this to run only once.
Actionable comments posted: 1
🧹 Nitpick comments (1)
examples/speculative_decoding/eagle_utils.py (1)
241-249: Cache the MT-Bench dataset outside the synchronized section.
`load_dataset(...)` now sits on the global stall path, so every worker waits for rank 0 to redo the hub/cache lookup before training resumes. Reusing a cached dataset here would cut that idle time.
♻️ Proposed refactor

```diff
 class EagleTrainingPlot(TrainerCallback):
     def __init__(self, ar_validate_steps: int = 1000, estimate_ar: bool = False):
         self.ar_validate_steps = ar_validate_steps
         if wandb and is_master():
             wandb.init()
         self.estimate_ar = estimate_ar
+        self._ar_validation_ds = None
@@
         if is_master():
             print_rank_0("Running AR validation...")
             try:
+                if self._ar_validation_ds is None:
+                    self._ar_validation_ds = load_dataset("HuggingFaceH4/mt_bench_prompts")[
+                        "train"
+                    ]
                 ars = validate_ar(
                     model=kwargs["model"],
                     tokenizer=kwargs["processing_class"],
-                    ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
+                    ds=self._ar_validation_ds,
                     device=kwargs["model"].device,
                 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/eagle_utils.py` around lines 241 - 249, The call to load_dataset("HuggingFaceH4/mt_bench_prompts") is happening inside the synchronized/critical section causing all ranks to wait; move the dataset load/cache out of that section and reuse a single cached dataset reference when calling validate_ar (keep calling validate_ar(model=kwargs["model"], tokenizer=kwargs["processing_class"], ds=cached_mt_bench_ds, device=kwargs["model"].device)), ensuring cached_mt_bench_ds is initialized once (e.g., at module import or rank-0 setup) and shared by workers so only the heavy hub/cache lookup happens once while still preserving use of state.global_step, ars, and wandb logging.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 238-254: Load the validation dataset before the master-only block
so non-master ranks don't waste time; specifically, when checking the AR
validation trigger (state.global_step % self.ar_validate_steps == 0 and
state.global_step > 0) call
load_dataset("HuggingFaceH4/mt_bench_prompts")["train"] into a local variable
(e.g., ds) before the is_master() check, then inside the is_master() block call
validate_ar(model=kwargs["model"], tokenizer=kwargs["processing_class"], ds=ds,
device=kwargs["model"].device) as before; keep print_rank_0, the try/except
around validate_ar, and the torch.distributed.barrier() after the block to
preserve synchronization.
---
Nitpick comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 241-249: The call to
load_dataset("HuggingFaceH4/mt_bench_prompts") is happening inside the
synchronized/critical section causing all ranks to wait; move the dataset
load/cache out of that section and reuse a single cached dataset reference when
calling validate_ar (keep calling validate_ar(model=kwargs["model"],
tokenizer=kwargs["processing_class"], ds=cached_mt_bench_ds,
device=kwargs["model"].device)), ensuring cached_mt_bench_ds is initialized once
(e.g., at module import or rank-0 setup) and shared by workers so only the heavy
hub/cache lookup happens once while still preserving use of state.global_step,
ars, and wandb logging.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 43b9f242-1583-40b4-8db9-7a3aeb9b07cd
📒 Files selected for processing (1)
examples/speculative_decoding/eagle_utils.py
```python
if is_master():
    print_rank_0("Running AR validation...")
    try:
        ars = validate_ar(
            model=kwargs["model"],
            tokenizer=kwargs["processing_class"],
            ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
            device=kwargs["model"].device,
        )
        print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
        if wandb:
            wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
    except Exception:
        print_rank_0("AR validation not available.")
# Barrier to synchronize all ranks after validation
if torch.distributed.is_initialized():
    torch.distributed.barrier()
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
AR_VALIDATE_FILE="$(fd -a '^ar_validate\.py$' scripts | head -n 1)"
test -n "$AR_VALIDATE_FILE"
echo "== validate_ar implementation: $AR_VALIDATE_FILE =="
sed -n '1,260p' "$AR_VALIDATE_FILE"
echo
echo "== distributed/sharding-sensitive calls in validate_ar =="
rg -n -C2 'unwrap|model\.module|distributed|barrier|all_reduce|all_gather|get_rank|fsdp|deepspeed|context_parallel|tensor_parallel|pseudo_speculative_generate|generate\(' "$AR_VALIDATE_FILE" || true
echo
echo "== Eagle/distributed call sites =="
rg -n -C2 'EagleTrainingPlot|patch_ring_attention_for_ttt|cp_size|tp_size|context parallel|tensor parallel' examples/speculative_decoding || true
Repository: NVIDIA/Model-Optimizer
Length of output: 160
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# First, explore top-level structure
echo "=== Repository structure ==="
ls -la | head -20
echo
echo "=== Looking for ar_validate or validate_ar files ==="
fd -i 'ar_validate|validate.*ar' --type f
echo
echo "=== Looking for scripts directory ==="
find . -maxdepth 2 -type d -name 'scripts' 2>/dev/null || echo "No scripts dir found"
echo
echo "=== Searching for validate_ar function definition ==="
rg -l 'def validate_ar' --type py
Repository: NVIDIA/Model-Optimizer
Length of output: 1578
🏁 Script executed:
#!/bin/bash
set -euo pipefail
VALIDATE_FILE="examples/speculative_decoding/scripts/ar_validate.py"
# Read the file to understand validate_ar implementation
echo "=== Full ar_validate.py content ==="
cat -n "$VALIDATE_FILE"Repository: NVIDIA/Model-Optimizer
Length of output: 4763
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Searching for HFARValidation class ==="
rg -n 'class HFARValidation|def validate\(' --type py -B2 -A10 | head -60
echo
echo "=== Checking eagle_utils.py context around validation ==="
cat -n examples/speculative_decoding/eagle_utils.py | sed -n '220,260p'
echo
echo "=== Searching for model wrapping/preparation in eagle_utils.py ==="
rg -n 'unwrap|DistributedDataParallel|FSDP|tensor_parallel|context_parallel' examples/speculative_decoding/eagle_utils.py || echo "No explicit wrapping patterns found"
Repository: NVIDIA/Model-Optimizer
Length of output: 7010
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Searching for pseudo_speculative_generate definition ==="
rg -n 'def pseudo_speculative_generate' --type py -B2 -A15 | head -50
echo
echo "=== Checking eagle training script for model setup ==="
fd -i 'eagle.*train' --type py examples/speculative_decoding | head -5
echo
echo "=== Checking imports in eagle_utils.py ==="
head -50 examples/speculative_decoding/eagle_utils.py
Repository: NVIDIA/Model-Optimizer
Length of output: 4716
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Full pseudo_speculative_generate from transformers.py ==="
rg -n 'def pseudo_speculative_generate' modelopt/torch/speculative/plugins/transformers.py -A30
echo
echo "=== Eagle training script ==="
fd -type f -name '*eagle*train*' examples/speculative_decoding
echo
echo "=== First 80 lines of eagle_utils.py imports ==="
head -80 examples/speculative_decoding/eagle_utils.py
Repository: NVIDIA/Model-Optimizer
Length of output: 1651
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Full pseudo_speculative_generate from transformers.py ==="
sed -n '1055,1100p' modelopt/torch/speculative/plugins/transformers.py
echo
echo "=== Eagle training scripts ==="
find examples/speculative_decoding -name '*train*' -type f 2>/dev/null
echo
echo "=== First 80 lines of eagle_utils.py ==="
head -80 examples/speculative_decoding/eagle_utils.py
echo
echo "=== Check if model gets prepared with accelerate/DDP ==="
rg -n 'prepare|DistributedDataParallel|FSDP' examples/speculative_decoding --type py | head -30
Repository: NVIDIA/Model-Optimizer
Length of output: 4722
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Check pop_and_gather_aux_hiddens for collective operations ==="
rg -n 'pop_and_gather_aux_hiddens|all_gather|broadcast|reduce' modelopt/torch/speculative/plugins/transformers.py -B2 -A8 | head -80
echo
echo "=== Check if EAGLE model has distributed state ==="
rg -n 'class.*EAGLE|_aux_hidden_states|gather' modelopt/torch/speculative/plugins/transformers.py | head -30
echo
echo "=== Check training script for model preparation ==="
grep -r 'Trainer\|prepare\|accelerate' examples/speculative_decoding/train*.sh 2>/dev/null || echo "No train scripts with prepare"
Repository: NVIDIA/Model-Optimizer
Length of output: 4128
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Check if pop_and_gather_aux_hiddens has ANY distributed calls ==="
sed -n '532,560p' modelopt/torch/speculative/plugins/transformers.py
echo
echo "=== Check eagle_utils.py for context-parallel or tensor-parallel wrapping ==="
rg -n 'patch_ring_attention|cp_size|tp_size|tensor_parallel|context_parallel' examples/speculative_decoding/eagle_utils.py -B3 -A3
echo
echo "=== Check if model is actually wrapped in training callback context ==="
rg -n 'EagleTrainingPlot|on_step_end|kwargs\[' examples/speculative_decoding/eagle_utils.py | head -20
echo
echo "=== Look for where model might be context-parallel wrapped ==="
grep -r 'patch_ring_attention_for_ttt\|apply.*context' examples/speculative_decoding/ --include='*.py' | head -10
Repository: NVIDIA/Model-Optimizer
Length of output: 3447
Move dataset loading outside the master-only validation block for efficiency.
The validation code is safe as-is: validate_ar() contains no distributed collectives, and pop_and_gather_aux_hiddens() only performs local tensor concatenation (not a distributed gather). The barrier() correctly synchronizes ranks after rank-0 validation.
However, reloading the validation dataset inside the master-only block wastes idle time on non-master ranks. Load the dataset once before the validation check:
```python
if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
    ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"]
    if is_master():
        print_rank_0("Running AR validation...")
        try:
            ars = validate_ar(
                model=kwargs["model"],
                tokenizer=kwargs["processing_class"],
                ds=ds,
                device=kwargs["model"].device,
            )
            print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
            if wandb:
                wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
        except Exception:
            print_rank_0("AR validation not available.")
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/eagle_utils.py` around lines 238 - 254, Load
the validation dataset before the master-only block so non-master ranks don't
waste time; specifically, when checking the AR validation trigger
(state.global_step % self.ar_validate_steps == 0 and state.global_step > 0) call
load_dataset("HuggingFaceH4/mt_bench_prompts")["train"] into a local variable
(e.g., ds) before the is_master() check, then inside the is_master() block call
validate_ar(model=kwargs["model"], tokenizer=kwargs["processing_class"], ds=ds,
device=kwargs["model"].device) as before; keep print_rank_0, the try/except
around validate_ar, and the torch.distributed.barrier() after the block to
preserve synchronization.
LGTM. Shall we also add some unit/example tests in this PR?
Implement DFlash (Block Diffusion for Flash Speculative Decoding) as a new mode in ModelOpt's speculative decoding framework.

Key architecture:
- Feature Fusion: extract hidden states from uniformly sampled target model layers, project via FC layer
- KV Injection: fused target features injected as K/V entries in every draft decoder layer's attention (not just first layer input)
- Parallel Drafting: all tokens in a block predicted simultaneously using learnable mask embeddings and bidirectional within-block attention

Files:
- dflash/ module: DFlashModel, DFlashConfig, conversion, default config
- plugins/hf_dflash.py: HFDFlashModel with DFlashAttention (KV injection), DFlashModule (feature fusion + decoder), training forward pass with random anchor sampling and exponential position decay loss
- main.py: --mode dflash support in training script

Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
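As a rough illustration of the feature-fusion step described in this commit (the dimensions, layer count, and norm placement below are assumptions for the sketch, not the exact DFlashModule code):

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Minimal RMSNorm so the sketch is self-contained."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(var + self.eps)


class FeatureFusion(nn.Module):
    """Concatenate hidden states from k sampled target layers, then
    project back to the draft hidden size with a single FC layer."""

    def __init__(self, hidden_size: int, num_target_layers: int):
        super().__init__()
        self.fc = nn.Linear(hidden_size * num_target_layers, hidden_size, bias=False)
        self.norm = RMSNorm(hidden_size)

    def forward(self, target_hiddens: list[torch.Tensor]) -> torch.Tensor:
        fused = torch.cat(target_hiddens, dim=-1)  # [bsz, seq, k * hidden]
        return self.norm(self.fc(fused))           # [bsz, seq, hidden]


fusion = FeatureFusion(hidden_size=64, num_target_layers=3)
out = fusion([torch.randn(2, 8, 64) for _ in range(3)])
print(out.shape)  # torch.Size([2, 8, 64])
```

The fused output is what gets injected as K/V entries in the draft decoder layers; that injection step is not shown here.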
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Key fixes:
- mask_token_id now read from dflash_architecture_config (e.g., 248070 for Qwen3) instead of defaulting to pad/eos token. Wrong mask_token_id caused garbage draft output despite correct weights.
- Inherit model config from base model only as defaults; allow draft to have different num_heads/intermediate_size (needed for z-lab checkpoint)
- Clean default_dflash_config to only contain DFlash-specific settings
- pseudo_speculative_generate returns single block of tokens
- Add dflash_mask_token_id CLI argument to main.py

Validated: z-lab/Qwen3.5-4B-DFlash checkpoint produces AR=7.28 (expected ~6.08)

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Resolution order:
1. Explicit in dflash_architecture_config (user override)
2. Auto-detect from model vocabulary:
   - Qwen3/3.5: built-in [MASK] token (e.g., 248070)
   - Llama3: reserved_special_token_0 (128002)
   - Others: pad_token_id fallback
3. CLI override via --dflash_mask_token_id

Based on z-lab checkpoints:
- z-lab/Qwen3.5-4B-DFlash: mask=248070
- z-lab/LLaMA3.1-8B-Instruct-DFlash: mask=128002
- z-lab/gpt-oss-20b-DFlash: mask=200000

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
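The auto-detection step (2) could look roughly like this; the token strings and ids below are only the examples mentioned in this commit, not a guaranteed vocabulary layout:

```python
def detect_mask_token_id(vocab: dict[str, int], pad_token_id: int) -> int:
    """Illustrative fallback chain: built-in [MASK]-style token first,
    then Llama-3's reserved special token, then the pad token."""
    for candidate in ("[MASK]", "<|mask|>"):
        if candidate in vocab:
            return vocab[candidate]
    if "<|reserved_special_token_0|>" in vocab:
        return vocab["<|reserved_special_token_0|>"]
    return pad_token_id


print(detect_mask_token_id({"[MASK]": 248070}, pad_token_id=0))  # 248070
print(detect_mask_token_id({"<|reserved_special_token_0|>": 128002}, pad_token_id=0))  # 128002
print(detect_mask_token_id({}, pad_token_id=0))  # 0
```

An explicit config value or CLI override would take precedence over this heuristic, as in the resolution order above.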
AR validation runs pseudo_speculative_generate which does unsynchronized model forward passes. In multi-GPU DDP training, this caused NCCL timeout because other ranks were waiting at gradient sync. Fix: only run validate_ar on rank 0 (is_master()), add torch.distributed.barrier() after to synchronize all ranks. Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
super().forward() from HFDFlashModel goes through DynamicModule which dispatches back to HFDFlashModel.forward(), causing infinite recursion → stack overflow → NCCL timeout in multi-GPU training. Fix: use self._base_model() directly (same as pseudo_speculative_generate) for both eval-mode and training base model forward passes. Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The DynamicModule MRO correctly dispatches super().forward() to the original model class (e.g., Qwen3_5ForCausalLM.forward()) without looping — same pattern EAGLE uses successfully. The previous self._base_model() approach bypassed DDP, causing NCCL timeout because DDP's gradient sync couldn't track the forward pass. Keep pseudo_speculative_generate using self._base_model() since that runs outside DDP (single GPU AR validation). Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
When a rank's batch has no valid loss positions (e.g., all tokens in Block 0 which is excluded), the loss was a detached zero tensor with no connection to dflash_module parameters. DDP waited forever for gradient sync on those parameters → NCCL ALLREDUCE timeout.

Fix: use logits.sum() * 0.0 as zero loss, which maintains the computation graph through dflash_module parameters so DDP can sync zero gradients properly.

Also revert to super().forward() for training (matching EAGLE pattern) and add --ddp_find_unused_parameters True, --ddp_timeout 300.

Root cause analysis: rank 4 completed ALLREDUCE #272 and proceeded to ALLGATHER #273, while other ranks were stuck at ALLREDUCE #272. This indicated rank 4 had a different backward graph (no gradients for dflash_module on that rank).

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
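The graph-connected zero loss described here can be demonstrated in isolation (a toy linear layer stands in for the draft module):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)          # stand-in for dflash_module
logits = layer(torch.randn(2, 4))

bad_zero = torch.tensor(0.0)     # detached: backward() reaches no parameters
good_zero = logits.sum() * 0.0   # graph-connected: gradients flow, but are zero

good_zero.backward()
print(bad_zero.requires_grad)                # False
print(layer.weight.grad is not None)         # True
print(float(layer.weight.grad.abs().sum()))  # 0.0
```

Because every parameter receives an (exactly zero) gradient, DDP's allreduce sees the same set of gradient buckets on every rank and does not hang.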
Add --dflash_use_logit_distillation flag that switches from hard CE loss (predict ground truth tokens) to logit distillation (learn from target model's output distribution). Hard CE only works when training data is synthesized by the target model itself. Logit distillation works with any data because it learns from the target model's actual predictions, not the ground truth.

Usage: python main.py --mode dflash --dflash_use_logit_distillation ...

Config: dflash_self_logit_distillation (default=True in config, toggled via CLI flag)

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Pass answer_only_loss=True to LanguageDataCollator for DFlash mode. This makes the tokenizer return assistant_masks via apply_chat_template with return_assistant_tokens_mask=True. HFDFlashModel.forward() now checks for assistant_masks in kwargs and uses it as loss_mask instead of attention_mask. This matches SpecForge's behavior of only computing loss on response tokens. SpecForge-trained checkpoint (response-only mask): AR=1.95 ModelOpt-trained checkpoint (all tokens mask): AR=1.15 Both with 30-35% training accuracy on same data. Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
When answer_only_loss=True, set labels=-100 for non-assistant tokens using the assistant_masks from tokenizer.apply_chat_template. This ensures DFlash forward() can derive response-only loss mask from labels != -100, without relying on HF Trainer to pass assistant_masks. Also revert hf_dflash.py to use labels-based loss mask instead of kwargs-based assistant_masks (Trainer strips unknown keys). Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
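The labeling rule this commit describes is simple enough to sketch directly (-100 is the default ignore_index of PyTorch's cross-entropy; the helper name is illustrative):

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_non_assistant_labels(input_ids: list[int], assistant_mask: list[int]) -> list[int]:
    """Keep labels only where the assistant mask is 1; ignore the rest."""
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, assistant_mask)]


labels = mask_non_assistant_labels([11, 22, 33, 44], [0, 0, 1, 1])
print(labels)  # [-100, -100, 33, 44]
```

Downstream, the forward pass can then recover the response-only loss mask as `labels != -100` without needing the Trainer to forward any extra keys.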
When answer_only_loss=True and the tokenizer's return_assistant_tokens_mask returns empty/unsupported results, fall back to regex-based detection of assistant spans in the formatted text (similar to SpecForge's approach). Supports Qwen/ChatML, Llama3, Llama2, and generic assistant patterns. Uses tokenizer offset_mapping to map character spans to token positions. DFlash forward uses labels != -100 to derive the response-only loss mask. Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Documents DFlash architecture, training usage, mask_token_id auto-detection, and current status including the known AR gap from data pipeline differences. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Instead of hardcoding Llama components (LlamaMLP, LlamaRMSNorm, LlamaRotaryEmbedding), dynamically resolve them from the base model's transformers module (e.g., Qwen3MLP for Qwen3 models). Falls back to Llama components for unknown model types. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Two bugs prevented response-only masking from working: 1. main.py never passed answer_only_loss=True to the data collator for DFlash mode, so all tokens had labels (511/512 instead of response-only). 2. HFDFlashModel.forward() used attention_mask (padding mask) for loss masking instead of labels. When answer_only_loss is enabled, the response-only information is in labels (where -100 = ignore), but this was completely ignored. Now uses labels when available. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
- Add common/dflash/online_training.sh for launcher - Add examples/Qwen/Qwen3-8B/hf_online_dflash.yaml - Add --mode dflash support to launch_train.sh with DFlash-specific args (block_size, num_layers, mask_token_id, config) - DFlash uses DDP instead of FSDP for training Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
global_vars keys conflict with nemo_run's CLI parser when using --yaml format. Inline the values directly instead. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Tests cover:
- Model conversion (creates HFDFlashModel, DFlashModule, freezes base, sets target_layer_ids and mask_token_id)
- Save/restore via HuggingFace checkpointing
- Attention mask (shape, strictly-previous-block context, causal noise)
- Loss mask (excludes block 0 and block starts, correct count)
- Draft module forward (output shape, determinism)
- Training forward (loss, accuracy, labels masking, all-masked edge case, gradient flow, eval mode fallback)
- Target layer ID selection (single/multiple layers, spread, bounds)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Force-pushed the branch from d9cba02 to c4a3ecb (Compare)
- Add common/dflash/ar_validate.sh for acceptance rate evaluation - Insert as task_1 in hf_online_dflash.yaml (train → AR → benchmark) - Uses single GPU, loads trained checkpoint from /scratchspace/dflash - Evaluates on MT-Bench prompts with pseudo_speculative_generate Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Move GPU-dependent tests (training forward, module forward) from tests/unit/ to tests/gpu/torch/speculative/plugins/. CPU-only tests (masks, layer IDs, convert, save/restore) remain in tests/unit/. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
DFlash uses reverse-causal within blocks (matching SpecForge): earlier positions see more noise keys, later positions see fewer. This is intentional for block diffusion denoising. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
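A hedged sketch of such a reverse-causal within-block mask (the helper name and boolean convention are illustrative; the real attention mask also covers context and previous-block entries):

```python
import torch


def reverse_causal_block_mask(block_size: int) -> torch.Tensor:
    """mask[i, j] is True when draft position i may attend to noise position j.
    Allowing j >= i means earlier positions see more noise keys, later ones fewer."""
    idx = torch.arange(block_size)
    return idx.unsqueeze(1) <= idx.unsqueeze(0)


mask = reverse_causal_block_mask(4)
print(mask.long().tolist())  # [[1, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1]]
```

Row 0 (the earliest draft position) attends to all four noise keys, while row 3 attends only to its own, matching the denoising intuition in the commit message.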
cjluo-nv
left a comment
Summary: Adds DFlash block diffusion speculative decoding as a new mode in ModelOpt. Introduces a new dflash/ module with config, conversion, and model classes, an HF plugin (hf_dflash.py) implementing training/generation, CLI integration, and both unit and GPU tests.
Issues Found:
-
[Correctness] Global mutable state in
hf_dflash.py(lines 89-90, 371-375): The module-level globals_MLP_CLS,_NORM_CLS,_ROTARY_CLS,_rotate_halfare reassigned insideHFDFlashModel.modify(). This is not thread-safe and breaks if two models with differentmodel_typeare converted in the same process. TheDFlashAttention.__init__andDFlashDecoderLayer.__init__capture these globals at construction time, butapply_rotary_pos_embcaptures_rotate_halfat module level and would use whichever was set last. Consider passing these as constructor arguments or storing them on the config/module instance. -
[Correctness]
_process_chat_sampleoverride ignores existing tokenizer-based masking (transformers_dataset.pylines 189-199): The existing_process_chat_samplealready passesreturn_assistant_tokens_mask=self.answer_only_losstoapply_chat_template(line 183). The PR adds ~70 lines of regex-based fallback logic (_apply_answer_only_labels) but never uses theassistant_masksthat the tokenizer already returns. The checkif "assistant_masks" in tokenized_exampleson the new line 191 references a key that the tokenizer returns asassistant_masksonly when the feature is supported — butapply_chat_templatewithreturn_assistant_tokens_mask=Truereturns it in the dict. The code path seems confused: if the tokenizer supportsreturn_assistant_tokens_mask, the mask is already available and should be applied directly. The regex fallback with hardcoded patterns (<|im_start|>assistant,<|start_header_id|>assistant) is fragile and duplicates tokenizer functionality. -
[Correctness]
main.pyline 266:json.load(open(...))— unclosed file handle: Usewith open(...)orPath(...).read_text()+json.loads(). -
[Correctness] Mask token auto-detection heuristics (
hf_dflash.pylines 290-326) are brittle: The offsets (26, 25, 24) for Qwen3.5,vocab_size - 250for smaller Qwen, and hardcoded128002for Llama3 are magic numbers that will silently produce wrong tokens on new model versions. A wrong mask token will cause silent training failures. At minimum, add a validation warning when falling through to heuristic paths. -
[Correctness]
hf_dflash.pyline 456 —super().forward()in training callsDynamicModule.forwardwhich dispatches to the current class'sforward(), potentially causing infinite recursion depending on MRO. The eval path (line 430) also callssuper().forward(). This relies onDynamicModuledispatching to_original_forward_cls.forward(). Compare with EAGLE which usesself._base_forward(). The training path base model call at line 446 also usessuper().forward()— verify this doesn't re-enter the DFlash training forward. -
[Correctness]
on_step_endbarrier placement (eagle_utils.pylines 261-263): The barrier runs outside theif is_master()block, meaning all ranks hit it. But non-master ranks skip the validation entirely and proceed straight to the barrier. Ifvalidate_aron the master is slow, other ranks will block at the barrier, which is the intended behavior. However, ifvalidate_arraises an exception on the master, the master skips the barrier (caught by try/except), causing a deadlock on all non-master ranks. The barrier should be inside afinallyblock. -
[Readability] Excessive
print()statements inhf_dflash.py(lines 356, 376, 378, ~530-540): Production code should useprint_rank_0orlogginginstead of bareprint(). The debug prints inpseudo_speculative_generate(lines 520-535) with_psg_debugflag are development leftovers. -
[Readability]
hf_dflash.pyline 362 —_attn_implementationaccess: Accessingself.dflash_config._attn_implementationuses a private attribute ofPretrainedConfig. This is fragile across transformers versions. -
[Duplicated Code] Plugin registration pattern: The PR registers
hf_dflashwithimport_plugin("hf_dflash")inplugins/__init__.py, buthf_dflashis not a third-party dependency — it's always available. The existing plugins useimport_plugin("transformers")andimport_plugin("megatron_eagle")because those depend on optional packages.hf_dflashdepends ontransformers(same as the existingtransformersplugin). It should either be: (a) imported inside the existingwith import_plugin("transformers")block, or (b) imported unconditionally. Usingimport_plugin("hf_dflash")will silently swallowImportErrors that should be real errors. -
- [Tests] No test for `pseudo_speculative_generate`: the GPU tests cover the training forward but not the generation path, and the unit tests cover masks/config but not generation. Given that generation is a key feature, at least one test verifying the output shape would be valuable.
- [Tests] `ar_validate.sh` line 79 uses `trust_remote_code=True`: the PR checklist mentions following security best practices, and the pre-merge check for security anti-patterns somehow passed, yet `trust_remote_code=True` appears in the validation script.
Suggestions:
- `DFlashConfig.dflash_self_logit_distillation` defaults to `True` in the config class, but the corresponding `DFlashArguments` field in `main.py` defaults to `False`. This inconsistency will confuse users — align them.
- The `dflash_architecture_config` field in `DFlashConfig` defaults to `{}` (a mutable default). Use `default_factory=dict` or a similar pattern to avoid shared mutable defaults.
- Consider adding docstring coverage for the public functions to address the pre-merge check failure (60.61% vs the required 80%).
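The mutable-default fix is a one-liner with dataclasses; this fragment is illustrative only (not the real `DFlashConfig`):

```python
from dataclasses import dataclass, field


@dataclass
class ConfigSketch:
    """Sketch of the default_factory fix for a mutable `{}` default."""

    # Each instance gets its own fresh dict instead of sharing one object.
    dflash_architecture_config: dict = field(default_factory=dict)


a, b = ConfigSketch(), ConfigSketch()
a.dflash_architecture_config["hidden_size"] = 1024
print(b.dflash_architecture_config)  # -> {} : no state leaks across instances
```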
Overall Assessment: the architecture is sound and follows existing ModelOpt patterns well. However, there are several correctness concerns: the global mutable state for model components is a design flaw that will cause issues in multi-model scenarios, the `answer_only_loss` implementation in the data collator duplicates existing tokenizer functionality with a fragile regex, and the barrier placement in `on_step_end` can deadlock. The debug prints should be cleaned up before merge.
Implement DFlash (Block Diffusion for Flash Speculative Decoding) as a new mode in ModelOpt's speculative decoding framework.
Key architecture:
Files:
Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)
What does this PR do?
Type of change: ?
Usage
# Add a code snippet demonstrating how to use this
Testing
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
CONTRIBUTING.md: ✅ / ❌ / N/A
Additional Information
Summary by CodeRabbit
New Features
Bug Fixes