Replaces ModuleSpec with Protocols for some of the inputs to SelfAttention/CrossAttention #2761
Conversation
Rebased to capture other PR
```diff
  attn_mask_type = inputs[5]
  attn_mask_type = AttnMaskType(attn_mask_type.item())
- output_ = self.core_attention(
+ output_ = apply_module(self.core_attention)(
```
@ericharper @jaredcasper - for any thoughts on the usage of apply_module.
To give my thoughts on `apply_module`:

- Firstly, I won't be offended by trying to find an alternative (it works, but I recognize it's odd).
- To me, the main drawbacks it has are that it can be distracting and ends up being in a bunch of places. However, I liken it to `cast` or little things like that - people used to the codebase very quickly start ignoring it and knowing when to use it without effort. (That's what happened in our old fork of Megatron anyway.)
- As a major plus during the migration, it forces me to show you every place where a particular module is being used, because it gives me a reason to wrap every call, which might help spot unexpected diffs.
- In general, it makes it so each `Protocol` only needs to show exactly what it means without needing to follow any additional magical patterns (like the only potential alternative, shown below), and the weird complexity is instead wrapped in a unified interface that effectively becomes a 'code-base wide' central link back to a place that can explain what's going on.
- Finally, `apply_module` works even in cases where a module's type is known, while alternatives that rely on `Protocol` do not. For instance, in some submodules of `TransformerLayer`, the correct type is `BuilderProtocol | type[IdentityOp]`; `apply_module` is able to understand this, but if you try to use `__call__` in this case you will get an untyped result because `IdentityOp.__call__` is untyped. More broadly, using `apply_module` everywhere (even when you're using a non-configurable module with a fixed type) has nice typing benefits you can't easily get otherwise. (A hedged sketch of the helper's general shape follows below.)
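For readers following along, here is a minimal, hypothetical sketch of the general shape a helper like `apply_module` could take; the names and typing strategy below are assumptions for illustration, not the PR's actual implementation (`ParamSpec` requires Python 3.10+ or `typing_extensions`):

```python
from typing import Callable, ParamSpec, Protocol, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


class HasForward(Protocol[P, R]):
    # Anything with a typed `forward`; torch modules and the PR's Protocols fit.
    forward: Callable[P, R]


def apply_module(module: HasForward[P, R]) -> Callable[P, R]:
    # Runtime no-op: the returned object is the module itself, so calling it
    # still goes through torch.nn.Module.__call__ (hooks and all). The value
    # is purely static: the call site is typed with `forward`'s signature
    # instead of an untyped `__call__`.
    return module  # type: ignore[return-value]
```

Usage would then mirror the hunk above, e.g. `output_ = apply_module(self.core_attention)(...)`.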
The only alternative idea that doesn't require providing the entire signature twice is to redefine `__call__` based on `forward`, like this:

```python
class Unary(Protocol):
    def forward(self, x: torch.Tensor, /) -> torch.Tensor: ...
    __call__ = forward
```

I mentioned to Yash that this might be workable, but after some research I think it's a bad idea. It works by happenstance, but has three pretty significant drawbacks:

- If you flip the order (define the signature for `__call__` and use `forward = __call__`), then this `Protocol` will silently not type-check correctly (it will accept any `Module` without complaint).
- It is doing something much more subtle than it appears - it is not equivalent to 'copying the signature of `forward` to `__call__`'; it is the equivalent of providing a default definition for `__call__`, which is (in fact) possible to use. Just in broad terms, doing this would leave around a rather misleading/dangerous pattern for people to misunderstand.
- As an explicit example of where this could actually be a bug: if you explicitly extend that protocol, as in `class Example(Unary, torch.nn.Module)` (note the order of subclasses), this would actually break the Module and override torch's builtin `__call__` behavior to simply call `forward` directly instead. (A minimal reproduction is sketched after this list.)
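To make that last failure mode concrete, here is a small, self-contained reproduction sketch; the `Example` class is hypothetical and purely illustrative:

```python
import torch
from typing import Protocol


class Unary(Protocol):
    def forward(self, x: torch.Tensor, /) -> torch.Tensor: ...
    __call__ = forward


class Example(Unary, torch.nn.Module):  # note: Unary listed before nn.Module
    def forward(self, x: torch.Tensor, /) -> torch.Tensor:
        return x + 1


# The Protocol's `__call__ = forward` assignment now shadows
# torch.nn.Module.__call__ in the MRO, so the hook machinery never runs.
assert Example.__call__ is Unary.forward
assert Example.__call__ is not torch.nn.Module.__call__

m = Example()
# In this minimal case the call resolves to the Protocol's stub `forward`
# (whose body is `...`), so it returns None instead of dispatching to
# Example.forward through nn.Module's usual __call__ machinery.
print(m(torch.ones(1)))  # -> None
```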
```diff
- try:
+ if TYPE_CHECKING:
```
neat :)
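For context on the hunk above, the general pattern being gated here keeps the optional dependency visible to the type checker while leaving it optional at runtime. A hedged sketch of that pattern (import names are assumptions; the PR's actual TE extension code may differ):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # The type checker always "sees" Transformer Engine, so annotations
    # referencing te.pytorch classes can resolve.
    import transformer_engine as te
else:
    # At runtime the dependency remains optional.
    try:
        import transformer_engine as te
    except ImportError:
        te = None
```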
Thank you for also fixing some incorrect type hints! Overall lgtm. I'll let @jaredcasper @ericharper comment if they have any thoughts on the usage of apply_module in the forward passes of modules.
What does this PR do?
Introduces `Protocol`s for the first half of the inputs to `SelfAttention` and `CrossAttention`.

Associated design doc: https://docs.google.com/document/d/1shyv0iKEzRdevLOlouF_NktbdJazvWifqxUwPXFigQE/edit?tab=t.0#heading=h.uwes2zo47yg6
Stopped here just for ease of review/to keep things more focused; as more typing is corrected throughout the codebase, PRs will slowly get smaller since there's less to fix.
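To make the idea concrete, here is an illustrative sketch of the general shape such a `Protocol` could take for one attention input; the class name and signature are assumptions, not the PR's actual definitions:

```python
from typing import Optional, Protocol, Tuple

import torch


class LinearProjectionProtocol(Protocol):
    # Illustrative shape for a projection submodule that, like Megatron's
    # parallel linear layers, returns (output, optional bias).
    def forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ...
```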
Notes
- Added a `not_none` utility because the patterns used by config objects and imports make it a common need; it's optional though, I can drop it if it doesn't look useful. (A hedged sketch of its likely shape is included below.)
- `attention_mask` is the most interesting one, and I've verified that it is indeed allowed to be `None` in the implementations.
- Without the `TYPE_CHECKING` escape hatch in the TE extension file, no type checking can be performed on any TE code, because `te.pytorch.XXX` is not known to be a class, so all "subclasses" in the file make the typechecker really unhappy. This doesn't affect any real code though.
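A minimal sketch of what such a `not_none` helper could look like (assumption: the actual utility in the PR may differ in name, location, or error handling):

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def not_none(value: Optional[T], message: str = "expected a non-None value") -> T:
    """Narrow Optional[T] to T, raising if the value is actually None."""
    if value is None:
        raise ValueError(message)
    return value


# Illustrative use with a config attribute that is Optional at construction
# time but required at call time (attribute name is hypothetical):
# kv_channels = not_none(config.kv_channels)
```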
Contribution process

```mermaid
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
```

Pre-checks
Code review
The following process is enforced via the CODEOWNERS file for changes into `megatron/core`. For changes outside of `megatron/core`, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch
(Step 1): Add PR label `Expert Review`

(Step 2): Collect the expert reviewers' reviews

- Add the `Expert Review` label when your PR is ready for review.
- Final Review might get declined if these requirements are not fulfilled.
(Step 3): Final Review

- Add the `Final Review` label.

(Optional Step 4): Cherry-pick into release branch
If this PR also needs to be merged into `core_r*` release branches, after this PR has been merged, select `Cherry-pick` to open a new PR into the release branch.

For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either `[email protected]` or `[email protected]`.

Merging your PR
Any member of `core-adlr` and `core-nemo` will be able to merge your PR.