
Conversation

@nschank (Contributor) commented Dec 25, 2025

What does this PR do?

Introduces Protocols for the first half of the inputs to SelfAttention and CrossAttention.

Associated design doc: https://docs.google.com/document/d/1shyv0iKEzRdevLOlouF_NktbdJazvWifqxUwPXFigQE/edit?tab=t.0#heading=h.uwes2zo47yg6

Stopped here just for ease of review and to keep things focused; as more typing is corrected throughout the codebase, later PRs will get smaller since there will be less to fix.

Notes

  • Added a not_none utility because the patterns used by config objects and imports make it a common need; it's optional though, and I can drop it if it doesn't look useful. (A minimal sketch of the idea follows this list.)
  • I had to update the return types and argument types of several fundamental modules. These should all be pretty trivial; attention_mask is the most interesting one, and I've verified that it is indeed allowed to be None in the implementations.
  • Note that, without the odd TYPE_CHECKING escape hatch in the TE extension file, no type checking can be performed on any TE code, because te.pytorch.XXX is not known to be a class and all "subclasses" in the file therefore make the type checker very unhappy. This doesn't affect any real code, though.
  • Apologies for the autoformatting in the examples - I can revert them if desired, but they all seemed correct so 🤷
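For reference, a minimal sketch of what a not_none helper along these lines typically looks like; the exact signature and error behavior in the PR may differ.

from typing import Optional, TypeVar

T = TypeVar("T")

def not_none(value: Optional[T], message: str = "unexpectedly None") -> T:
    """Narrow an Optional[T] down to T, raising if the value is actually None."""
    if value is None:
        raise ValueError(message)
    return value

# Hypothetical usage: narrow an optional config field before handing it on.
# hidden_size = not_none(config.hidden_size)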

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see the Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either [email protected] or [email protected].

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@nschank nschank requested review from a team as code owners December 25, 2025 23:22
@copy-pr-bot (bot) commented Dec 25, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@nschank (Contributor, Author) commented Jan 3, 2026

Rebased to capture the other PR.

attn_mask_type = inputs[5]
attn_mask_type = AttnMaskType(attn_mask_type.item())
- output_ = self.core_attention(
+ output_ = apply_module(self.core_attention)(
Review comment (Contributor):

@ericharper @jaredcasper - for any thoughts on the usage of apply_module.

Reply (Contributor, Author):

To give my thoughts on apply_module:

  • Firstly, I won't be offended by trying to find an alternative (it works but I recognize it's odd)
  • To me, the main drawbacks it has are that it can be distracting and ends up being in a bunch of places. However, I liken it to cast or little things like that - people used to the codebase very quickly start ignoring it and knowing when to use it without effort. (That's what happened in our old fork of Megatron anyway)
  • As a major plus during the migration, it forces me to show you every place where a particular module is being used because it gives me a reason to wrap every call, which might help spot unexpected diffs.
  • In general, it means each Protocol only needs to state exactly what it means, without following any additional magical patterns (like the only potential alternative, shown below); the odd complexity is instead wrapped in a unified interface that effectively becomes a codebase-wide central link back to a place that can explain what's going on.
  • Finally, apply_module works even in cases where a module's type is known, while alternatives that rely on Protocol do not. For instance, in some submodules of TransformerLayer, the correct type is BuilderProtocol | type[IdentityOp]; apply_module can handle this, but if you try to use __call__ in this case you will get an untyped result because IdentityOp.__call__ is untyped. More broadly, using apply_module everywhere (even when you're using a non-configurable module with a fixed type) has nice typing benefits you can't easily get otherwise. (A hypothetical sketch of such a helper follows this list.)
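To make the discussion concrete, here is one hypothetical way a helper like apply_module could be written. This is only an illustrative sketch, not necessarily the PR's implementation; ForwardModule is an invented name used solely for this example.

from typing import Callable, ParamSpec, Protocol, TypeVar, cast

P = ParamSpec("P")
R = TypeVar("R", covariant=True)

class ForwardModule(Protocol[P, R]):
    # Hypothetical: any module whose forward has the signature P -> R.
    def forward(self, *args: P.args, **kwargs: P.kwargs) -> R: ...

def apply_module(module: ForwardModule[P, R]) -> Callable[P, R]:
    # Runtime no-op: nn.Module.__call__ dispatches to forward (plus hooks),
    # so returning the module itself preserves behavior while giving the
    # type checker forward's signature at the call site.
    return cast(Callable[P, R], module)

# Hypothetical call site, mirroring the diff above:
# output_ = apply_module(self.core_attention)(query, key, value, attention_mask)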

The only alternative idea that doesn't require providing the entire signature twice is to redefine __call__ based on forward like this:

from typing import Protocol

import torch

class Unary(Protocol):
  def forward(self, x: torch.Tensor, /) -> torch.Tensor: ...

  __call__ = forward

I mentioned to Yash that this might be workable, but after some research I think it's a bad idea. It works by happenstance, and it has three pretty significant drawbacks:

  1. If you flip the order (define the signature on __call__ and use forward = __call__), the Protocol silently stops type-checking correctly (it will accept any Module without complaint).
  2. It is doing something much subtler than it appears: it is not equivalent to 'copying the signature of forward onto __call__'; it is equivalent to providing a default definition for __call__, which is (in fact) possible to use. In broad terms, doing this would leave around a rather misleading and dangerous pattern for people to misunderstand.
  3. As an explicit example of where this could actually be a bug: if you explicitly extend that protocol, as in class Example(Unary, torch.nn.Module) (note the order of the base classes), this would actually break the Module and override torch's built-in __call__ behavior so that forward is effectively called directly instead. (A minimal sketch follows this list.)
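As a minimal sketch of drawback 3 (re-declaring the Unary protocol from above so the snippet is self-contained): the __call__ inherited from the protocol sits before torch.nn.Module.__call__ in the MRO, so torch's dispatch and hook machinery is shadowed.

from typing import Protocol

import torch

class Unary(Protocol):
    def forward(self, x: torch.Tensor, /) -> torch.Tensor: ...
    __call__ = forward

class Example(Unary, torch.nn.Module):  # note the order of the base classes
    def forward(self, x: torch.Tensor, /) -> torch.Tensor:
        return x * 2

# Example inherits __call__ from Unary rather than from torch.nn.Module, so
# torch's built-in __call__ (hook handling, dispatch to forward) is bypassed.
print(Example.__call__ is torch.nn.Module.__call__)  # False
print(Example.__call__ is Unary.__call__)            # True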

)

try:
if TYPE_CHECKING:
Review comment (Contributor):
neat :)
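For readers unfamiliar with the escape hatch mentioned in the notes, one common form this idiom takes is to import the real module only for the type checker while keeping the runtime try/except guard. This is a generic illustration of the pattern, not necessarily the exact change made in this file.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # The type checker always sees the real import, so te.pytorch classes are
    # known types and subclassing them can be checked.
    import transformer_engine as te
else:
    try:
        import transformer_engine as te
    except ImportError:
        te = None  # runtime fallback when Transformer Engine is unavailable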

@yashaswikarnati (Contributor) commented:
Thank you for also fixing some incorrect type hints! Overall lgtm.

I'll let @jaredcasper and @ericharper comment if they have any thoughts on the usage of apply_module in the forward passes of modules.
