[PyTorch] [torch.compile] Split linear forward into forward and setup context. #2811

pggPL wants to merge 4 commits into NVIDIA:main
Conversation
…etup_ctx_and_forward

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

# Conflicts:
#	transformer_engine/pytorch/module/linear.py
Greptile Summary

This PR refactors linear.py, splitting _Linear.forward into a shared compute step (_linear_forward_impl) and a context-setup step (_linear_setup_ctx).

Previously-flagged issues now resolved: backward gradient count mismatch, dead-code UB reset lines, and the undefined assert_dim_for_all_gather call.

Confidence Score: 4/5. Safe to merge with one P1 behavioural change that should be confirmed intentional before landing. The three previously-flagged critical issues (backward gradient count, dead UB-reset code, undefined assert_dim_for_all_gather) are all resolved. The refactoring structure is sound, the non_tensor_args tuple ordering is consistent across all three new helper functions, and the ctx setup/access patterns are correct. One P1 concern remains in transformer_engine/pytorch/module/linear.py, specifically in _linear_forward_impl around line 252.

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["Linear.forward()"] --> B{is_grad_enabled?}
B -- Yes --> C["_Linear.apply()"]
B -- No --> D["_Linear.forward(ctx=None)"]
C --> E["_Linear.forward(ctx)"]
D --> E
E --> F["_linear_forward_impl(weight, inp, bias, non_tensor_args, input_q, weight_q, output_q)"]
F --> G["returns (out, tensors_to_save, tensor_objects, ctx_attrs)"]
G --> H{ctx is not None?}
H -- Yes --> I["_linear_setup_ctx(ctx, ...)"]
H -- No --> J["return out"]
I --> K["Set ctx.fp8, ctx.quantizers, ctx.ub_flags, etc."]
K --> L{fp8 and requires_grad?}
L -- Yes --> M["ctx.reduce_and_update_bwd_fp8_tensors = _check_fp8_reduce_and_update()"]
L -- No --> N["ctx.reduce_and_update_bwd_fp8_tensors = False"]
M --> J
N --> J
J --> O["_Linear.backward(ctx, grad_output)"]
O --> P["_linear_backward(ctx, grad_output, input_q, weight_q, grad_*_q)"]
P --> Q["returns (wgrad, dgrad, grad_bias, None x7)"]
Q --> R{reduce_and_update_bwd_fp8_tensors?}
R -- Yes --> S["FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)"]
R -- No --> T["return result"]
S --> T
```
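The ctx=None dispatch in the flowchart can be sketched generically. This is a minimal pure-Python illustration of the pattern, not the real TransformerEngine code: all names below (`_LinearSketch`, `linear_forward`, `x`, `scale`) are hypothetical stand-ins, and a plain dict stands in for the autograd context object.

```python
class _LinearSketch:
    """Minimal stand-in for an autograd Function whose forward is split
    into a compute step and an optional context-setup step."""

    @staticmethod
    def forward(ctx, x, scale):
        out = x * scale  # stand-in for _linear_forward_impl
        if ctx is not None:
            # stand-in for _linear_setup_ctx: record what backward needs
            ctx["saved_scale"] = scale
        return out


def linear_forward(x, scale, is_grad_enabled):
    """Dispatch mirroring the flowchart: grad-enabled calls go through a
    context-carrying path; inference passes ctx=None and skips bookkeeping."""
    if is_grad_enabled:
        ctx = {}
        return _LinearSketch.forward(ctx, x, scale), ctx
    return _LinearSketch.forward(None, x, scale), None
```

The design point is that both branches share one forward body, so the compute path stays identical whether or not autograd state is being recorded.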
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
```python
return (
    wgrad,
    dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
    grad_bias,
    None,
)
```
Backward returns too few gradients for new forward inputs
_Linear.forward now accepts 10 positional inputs (excluding ctx): weight, inp, bias, non_tensor_args, input_quantizer, weight_quantizer, output_quantizer, grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer.
PyTorch's autograd engine requires backward to return one value per forward input (even non-tensor ones get None). The current return tuple has only 4 values, which will cause a runtime error at backprop time:
```
RuntimeError: function _LinearBackward returned an incorrect number of gradients (expected 10, got 4)
```
Six None values (one per quantizer argument) are missing from the return:
```diff
 return (
     wgrad,
     dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
     grad_bias,
-    None,
+    None,  # non_tensor_args
+    None,  # input_quantizer
+    None,  # weight_quantizer
+    None,  # output_quantizer
+    None,  # grad_input_quantizer
+    None,  # grad_weight_quantizer
+    None,  # grad_output_quantizer
 )
```
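The counting rule itself can be checked with plain Python. This is a sketch of autograd's per-input accounting, not the real backward; the input names are copied from the forward signature quoted above, and `backward_return_sketch` is a hypothetical helper.

```python
# Forward inputs of _Linear.forward (excluding ctx), as listed in the review.
FORWARD_INPUTS = (
    "weight", "inp", "bias", "non_tensor_args",
    "input_quantizer", "weight_quantizer", "output_quantizer",
    "grad_input_quantizer", "grad_weight_quantizer", "grad_output_quantizer",
)


def backward_return_sketch(wgrad, dgrad, grad_bias):
    # One slot per forward input; every non-tensor input gets None.
    non_tensor_slots = len(FORWARD_INPUTS) - 3  # non_tensor_args + 6 quantizers
    return (wgrad, dgrad, grad_bias) + (None,) * non_tensor_slots


grads = backward_return_sketch("wgrad", "dgrad", "grad_bias")
assert len(grads) == len(FORWARD_INPUTS)  # 10 entries, matching the forward
```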
```python
ub_overlap_rs_dgrad = False
ub_bulk_wgrad = False
ub_bulk_dgrad = False
```
Dead code in debug-mode UB reset block
_linear_forward_impl unpacks the dgrad/wgrad UB flags from non_tensor_args as _ub_overlap_rs_dgrad, _ub_bulk_dgrad, and _ub_bulk_wgrad (prefixed with _ to signal they are intentionally unused). The three assignments here create entirely new local variables that shadow nothing and are never read again within this function, making them dead code.
The corresponding adjustments for the backward path are correctly applied in _linear_setup_ctx (which re-reads the same flags from non_tensor_args and applies the debug override before writing to ctx). These three lines can safely be removed.
```diff
-ub_overlap_rs_dgrad = False
-ub_bulk_wgrad = False
-ub_bulk_dgrad = False
 if debug:  # turn off userbuffers in debug mode
     ub_overlap_rs_fprop = False
     ub_overlap_ag_fprop = False
```
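The dead-store issue can be reproduced in isolation. This is a purely illustrative sketch mirroring the review's description: `forward_impl_sketch` is hypothetical, and the underscore-prefixed names copy the unpack described above.

```python
def forward_impl_sketch(non_tensor_args, debug=True):
    # Backward-only flags unpacked with a leading underscore, marking them
    # as intentionally unused in the forward implementation.
    _ub_overlap_rs_dgrad, _ub_bulk_dgrad, _ub_bulk_wgrad = non_tensor_args

    if debug:
        # These bind brand-new locals that nothing below ever reads:
        # dead stores, exactly as flagged in the review. They do not
        # touch the underscore-prefixed names above.
        ub_overlap_rs_dgrad = False
        ub_bulk_wgrad = False
        ub_bulk_dgrad = False

    # The original flag values flow through unchanged.
    return _ub_overlap_rs_dgrad, _ub_bulk_dgrad, _ub_bulk_wgrad
```

Because the fresh locals are never read, removing them cannot change behavior, which is why the suggestion above deletes only those three lines.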
```python
own_quantized_input = False
if fp8:
    assert_dim_for_fp8_exec(inputmat, weight)
    assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
```
Undefined function: assert_dim_for_all_gather will cause NameError at runtime
assert_dim_for_all_gather is called here but is not defined anywhere in the codebase and was not present in the original _Linear.forward. A search of the entire repository confirms there is no definition or import of this function in any Python file.
In the original code (pre-PR), only assert_dim_for_fp8_exec was used at this point in the FP8 block. assert_dim_for_all_gather appears to be a new helper that was intended to be added alongside this refactor, but was inadvertently left out.
Any FP8-enabled forward pass will raise:
```
NameError: name 'assert_dim_for_all_gather' is not defined
```
You need to either define assert_dim_for_all_gather (by importing or implementing it) or remove this call if no additional dimension-check logic was intended for the all-gather path.
/te-ci pytorch
for more information, see https://pre-commit.ci
Description
This PR refactors linear.py to make it easier to add torch.compile support in upcoming PRs.
Fixes # (issue)
Type of change
Checklist: