
[PyTorch] [torch.compile] Split linear forward into forward and setup context. #2811

Open
pggPL wants to merge 4 commits into NVIDIA:main from pggPL:linear_split_into_setup_ctx_and_forward

Conversation

Collaborator

@pggPL pggPL commented Mar 30, 2026

Description

This PR refactors linear.py to make it easier to add torch.compile support in subsequent PRs.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL added 2 commits March 30, 2026 14:22
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…etup_ctx_and_forward

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

# Conflicts:
#	transformer_engine/pytorch/module/linear.py
@greptile-apps
Contributor

greptile-apps bot commented Mar 30, 2026

Greptile Summary

This PR refactors transformer_engine/pytorch/module/linear.py to split the monolithic _Linear.forward method into three separate top-level functions (_linear_forward_impl, _linear_setup_ctx, _linear_backward) in preparation for torch.compile support in follow-on PRs.

Key structural changes:

  • _linear_forward_impl performs the actual forward GEMM and returns (out, tensors_to_save, tensor_objects, ctx_attrs), making it callable without an autograd context.
  • _linear_setup_ctx saves forward state into the autograd ctx, called only when an autograd context exists.
  • _linear_backward contains the backward GEMM logic, now accepts quantizers as explicit parameters instead of reading them from ctx.
  • Six quantizer arguments are moved out of non_tensor_args and passed explicitly to _Linear.forward, with the backward return tuple extended to 10 values matching all 10 forward inputs.
  • A _check_fp8_reduce_and_update() helper is extracted to encapsulate the FP8 first-module state management.
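The forward/setup split described above can be sketched as a minimal torch.autograd.Function. This is an illustrative sketch only, with simplified names mirroring the PR's _linear_forward_impl and _linear_setup_ctx; it is not the Transformer Engine code, and the real helpers carry many more arguments and quantizer states.

```python
import torch

def _forward_impl(inp, weight):
    # Pure compute: a plain GEMM with no autograd-ctx access, so it can be
    # called without an autograd context (and traced by torch.compile).
    out = inp @ weight.t()
    tensors_to_save = (inp, weight)
    return out, tensors_to_save

def _setup_ctx(ctx, tensors_to_save):
    # Bookkeeping only: runs solely when an autograd context exists.
    ctx.save_for_backward(*tensors_to_save)

class _MiniLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        out, to_save = _forward_impl(inp, weight)
        if ctx is not None:  # ctx is None when called directly in no-grad paths
            _setup_ctx(ctx, to_save)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        # One gradient per forward input, in order: (inp, weight).
        return grad_out @ weight, grad_out.t() @ inp

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
y = _MiniLinear.apply(x, w)
y.sum().backward()
```

The key property is that _forward_impl never touches ctx, so the same compute path serves both the autograd and the inference (ctx=None) branches shown in the flowchart below.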

Previously-flagged issues now resolved: backward gradient count mismatch, dead-code UB reset lines, and undefined assert_dim_for_all_gather.

One remaining concern: the or debug condition was silently dropped from the weight-quantizer guard (see inline comment), which changes behaviour in debug mode with pre-quantized weights.

Confidence Score: 4/5

Safe to merge with one P1 behavioural change that should be confirmed intentional before landing.

The three previously-flagged critical issues (backward gradient count, dead UB-reset code, undefined assert_dim_for_all_gather) are all resolved. The refactoring structure is sound, the non_tensor_args tuple ordering is consistent across all three new helper functions, and the ctx setup/access patterns are correct. One P1 concern remains: the or debug condition was silently removed from the weight-quantizer guard in _linear_forward_impl, changing behaviour in debug mode with pre-quantized QuantizedTensor weights. Until this is confirmed intentional (or fixed), the score cannot reach 5.

transformer_engine/pytorch/module/linear.py — specifically _linear_forward_impl around line 252 where or debug was dropped from the weight-quantizer condition.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/linear.py Refactors _Linear by extracting _linear_forward_impl, _linear_setup_ctx, and _linear_backward; correctly adds 6 new quantizer arguments to forward/backward; one subtle behavior change: drops the or debug condition from the weight-quantizer guard, silently bypassing the debug quantizer when weight is already a QuantizedTensor.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Linear.forward()"] --> B{is_grad_enabled?}
    B -- Yes --> C["_Linear.apply()"]
    B -- No --> D["_Linear.forward(ctx=None)"]
    C --> E["_Linear.forward(ctx)"]
    D --> E
    E --> F["_linear_forward_impl(weight, inp, bias, non_tensor_args, input_q, weight_q, output_q)"]
    F --> G["returns (out, tensors_to_save, tensor_objects, ctx_attrs)"]
    G --> H{ctx is not None?}
    H -- Yes --> I["_linear_setup_ctx(ctx, ...)"]
    H -- No --> J["return out"]
    I --> K["Set ctx.fp8, ctx.quantizers, ctx.ub_flags, etc."]
    K --> L{fp8 and requires_grad?}
    L -- Yes --> M["ctx.reduce_and_update_bwd_fp8_tensors = _check_fp8_reduce_and_update()"]
    L -- No --> N["ctx.reduce_and_update_bwd_fp8_tensors = False"]
    M --> J
    N --> J
    J --> O["_Linear.backward(ctx, grad_output)"]
    O --> P["_linear_backward(ctx, grad_output, input_q, weight_q, grad_*_q)"]
    P --> Q["returns (wgrad, dgrad, grad_bias, None x7)"]
    Q --> R{reduce_and_update_bwd_fp8_tensors?}
    R -- Yes --> S["FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)"]
    R -- No --> T["return result"]
    S --> T
```

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +1034 to +1039
return (
    wgrad,
    dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
    grad_bias,
    None,
)
Contributor

P0 Backward returns too few gradients for new forward inputs

_Linear.forward now accepts 10 positional inputs (excluding ctx): weight, inp, bias, non_tensor_args, input_quantizer, weight_quantizer, output_quantizer, grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer.

PyTorch's autograd engine requires backward to return one value per forward input (even non-tensor ones get None). The current return tuple has only 4 values, which will cause a runtime error at backprop time:

RuntimeError: function _LinearBackward returned an incorrect number of gradients (expected 10, got 4)

Six None values (one per quantizer argument) are missing from the return:

Suggested change
return (
    wgrad,
    dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
    grad_bias,
    None,
)
return (
    wgrad,
    dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
    grad_bias,
    None,  # non_tensor_args
    None,  # input_quantizer
    None,  # weight_quantizer
    None,  # output_quantizer
    None,  # grad_input_quantizer
    None,  # grad_weight_quantizer
    None,  # grad_output_quantizer
)
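The autograd rule this comment cites can be demonstrated with a toy function, unrelated to the PR's code: backward must return exactly one value per forward input, with None for non-differentiable arguments such as quantizers or config objects.

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, factor, unused_cfg):
        # Three forward inputs (excluding ctx): inp, factor, unused_cfg.
        ctx.factor = factor
        return inp * factor

    @staticmethod
    def backward(ctx, grad_out):
        # Three forward inputs -> exactly three return values, None for the
        # non-tensor ones. Returning fewer raises a RuntimeError complaining
        # that the function "returned an incorrect number of gradients".
        return grad_out * ctx.factor, None, None

x = torch.ones(2, requires_grad=True)
y = Scale.apply(x, 3.0, {"debug": False})
y.sum().backward()
```

Dropping either trailing None here would reproduce, in miniature, the exact failure mode the review describes for the six missing quantizer gradients.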

Comment on lines +157 to +159
ub_overlap_rs_dgrad = False
ub_bulk_wgrad = False
ub_bulk_dgrad = False
Contributor

P2 Dead code in debug-mode UB reset block

_linear_forward_impl unpacks the dgrad/wgrad UB flags from non_tensor_args as _ub_overlap_rs_dgrad, _ub_bulk_dgrad, and _ub_bulk_wgrad (prefixed with _ to signal they are intentionally unused). The three assignments here create entirely new local variables that shadow nothing and are never read again within this function, making them dead code.

The corresponding adjustments for the backward path are correctly applied in _linear_setup_ctx (which re-reads the same flags from non_tensor_args and applies the debug override before writing to ctx). These three lines can safely be removed.

Suggested change
ub_overlap_rs_dgrad = False
ub_bulk_wgrad = False
ub_bulk_dgrad = False
if debug:  # turn off userbuffers in debug mode
    ub_overlap_rs_fprop = False
    ub_overlap_ag_fprop = False

own_quantized_input = False
if fp8:
    assert_dim_for_fp8_exec(inputmat, weight)
    assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
Contributor

P1 Undefined function assert_dim_for_all_gather will cause NameError at runtime

assert_dim_for_all_gather is called here but is not defined anywhere in the codebase and was not present in the original _Linear.forward. A search of the entire repository confirms there is no definition or import of this function in any Python file.

In the original code (pre-PR), only assert_dim_for_fp8_exec was used at this point in the FP8 block. assert_dim_for_all_gather appears to be a new helper that was intended to be added alongside this refactor, but was inadvertently left out.

Any FP8-enabled forward pass will raise:

NameError: name 'assert_dim_for_all_gather' is not defined

You need to either define assert_dim_for_all_gather (importing or implementing it) or remove this call if no additional dimension-check logic was intended for the all-gather path.
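If the missing helper is to be implemented rather than removed, it might look something like the following. This is a hypothetical sketch only, not Transformer Engine code: the divisibility contract, the world-size-free signature check, and the quantizer's "alignment" attribute (with a fallback of 16) are all assumptions made for illustration.

```python
import torch

def assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer):
    """Hypothetical check that input dims stay valid for an all-gather
    preceding an FP8 GEMM. Assumed contract: the last dim must be
    divisible by a quantizer-provided alignment."""
    if not with_input_all_gather or input_quantizer is None:
        return
    # "alignment" is an assumed attribute; 16 is an assumed FP8-friendly default.
    alignment = getattr(input_quantizer, "alignment", 16)
    if inputmat.shape[-1] % alignment != 0:
        raise AssertionError(
            f"Last dim {inputmat.shape[-1]} is not divisible by {alignment}, "
            "which this sketch requires for FP8 all-gather execution."
        )
```

Whatever the real contract turns out to be, the important part is that the function exists (or the call is deleted) before any FP8 forward pass exercises this branch.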

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Collaborator Author

pggPL commented Mar 30, 2026

/te-ci pytorch
