[DEV][WIP][DO NOT MERGE][REFACTOR] Introduce ContextParallelHandler for Unified Context Parallelism Abstraction #2749
Basic Usage

The following examples demonstrate how the `ContextParallelHandler` is used throughout the training loop and the model code.

1. Initialization and Batch Dispatch

In the training loop, the specific handler class is selected based on the backend (`args.transformer_impl`) and the CP communication type (`args.cp_comm_type`):

```python
def get_batch(...):
    # 1. Select the appropriate handler class based on arguments
    backend = args.transformer_impl
    cp_comm_type = args.cp_comm_type
    cp_handler_cls = get_cp_handler_cls(backend=backend, cp_comm_type=cp_comm_type)

    # 2. Instantiate the handler (assuming necessary metadata is passed here)
    cp_handler = cp_handler_cls(...)

    # 3. Use the handler to dispatch/slice the batch for the current CP rank
    batch = get_batch_on_this_cp_rank(batch, cp_handler=cp_handler)
    return batch
```
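To illustrate what the dispatch step can look like on the other side of this call, the sketch below shows one possible shape for the refactored `get_batch_on_this_cp_rank`: a thin loop that hands each tensor to the handler's `dispatch` hook. The `dispatch(tensor)` signature, the `cp_handler=None` fallback, and the handling of missing entries are assumptions for the sketch, not the implementation in this PR.

```python
from typing import Dict, Optional

import torch


def get_batch_on_this_cp_rank(
    batch: Dict[str, Optional[torch.Tensor]], cp_handler=None
) -> Dict[str, Optional[torch.Tensor]]:
    """Slice each per-sequence tensor in the batch for the current CP rank.

    Sketch only: the handler hides the backend-specific slicing pattern,
    so this function no longer needs to know about chunking or CP ranks.
    """
    if cp_handler is None:
        # CP disabled (or legacy path): return the batch unchanged.
        return batch
    return {
        key: cp_handler.dispatch(val) if val is not None else None
        for key, val in batch.items()
    }
```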
2. Model Forward Pass

The handler is threaded through the `GPTModel` forward pass, which uses it for CP-aware sequence-length calculation and rotary embedding generation:

```python
class GPTModel(nn.Module):
    def forward(self, ...):
        # Calculate sequence length requirements specific to the CP strategy
        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
            inference_context, self.decoder, decoder_input, self.config, cp_handler
        )

        # Generate Rotary Embeddings using the handler
        rotary_pos_emb = self.rotary_pos_emb(
            position_ids, self.mrope_section, cp_handler=cp_handler
        )

        # ... pass cp_handler to layers ...
```
3. Component Abstraction (RoPE & Attention)

Inside specific components, the handler manages communication (dispatch/combine) and backend-specific operations:

```python
class RotaryEmbedding(nn.Module):
    def forward(self, ...):
        # ... existing logic ...
        # Slice the embedding for the current CP rank (if necessary)
        emb = cp_handler.get_emb_on_this_cp_rank(emb)
        return emb

    def get_rotary_seq_len(self, ...):
        # Delegate sequence length calculation to the handler
        rotary_seq_len = cp_handler.get_rotary_seq_len(...)
        return rotary_seq_len


class Attention(nn.Module):
    def forward(self, ...):
        # ... logic ...
        # Apply RoPE using the handler (handles offsets and CP specifics)
        query = cp_handler.apply_rotary_pos_emb(
            query, q_pos_emb, config=self.config
        )
        key = cp_handler.apply_rotary_pos_emb(
            key, k_pos_emb, config=self.config
        )

        # Execute Core Attention
        # The handler manages the dispatch -> attention -> combine flow internally
        out = cp_handler.core_attn(
            attn_mod=self.core_attention,
            query=query,
            key=key,
            value=value,
            # ... other args
        )
        return out
```
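For readers unfamiliar with the dispatch -> attention -> combine flow mentioned in the `core_attn` comment, the skeleton below illustrates the control flow a backend-specific handler could implement. The class and method bodies are assumptions used only to show the flow; the real handlers (Transformer Engine, local, MagiAttention) perform their own backend-specific communication inside these hooks.

```python
import torch


class ExampleContextParallelHandler:
    """Illustrative skeleton only; not the handler shipped in this PR."""

    def dispatch(self, tensor: torch.Tensor) -> torch.Tensor:
        # Backend-specific: slice/exchange the sequence dimension for this CP rank.
        raise NotImplementedError

    def combine(self, tensor: torch.Tensor) -> torch.Tensor:
        # Backend-specific: gather per-rank outputs back to the full sequence.
        raise NotImplementedError

    def core_attn(self, attn_mod, query, key, value, **kwargs):
        # The flow referenced by the comment above: dispatch the inputs,
        # run the wrapped core attention module, then combine the outputs.
        query, key, value = (self.dispatch(t) for t in (query, key, value))
        out = attn_mod(query, key, value, **kwargs)
        return self.combine(out)
```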
What does this PR do ?
[Refactor] Introduce ContextParallelHandler for Unified Context Parallelism Abstraction
1. Design Philosophy: The ContextParallelHandler
This PR introduces the `ContextParallelHandler`, a unified abstraction layer designed to manage Context Parallelism (CP) operations across different attention backends and data formats. The primary goal is to decouple the parallelization strategy from the core model architecture, ensuring both extensibility and runtime efficiency.

We have refactored the context-parallel API to improve modularity and maintainability. The abstraction is encapsulated into two distinct categories (see the sketch below):

- Communication operations: the handler owns the CP communication primitives (e.g., `dispatch`, `combine`) and exposes high-level hooks for seamless interaction.
- Backend-specific operations (e.g., `rope`, `roll_tensor`): this ensures a non-intrusive integration strategy, decoupling CP logic from the model definition to keep the architecture clean.
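A minimal sketch of how these two categories could be grouped on the handler interface is shown below. Only the method names `dispatch`, `combine`, `roll_tensor`, and `apply_rotary_pos_emb` come from this PR's description; the signatures and docstrings are assumptions.

```python
from abc import ABC, abstractmethod

import torch


class ContextParallelHandler(ABC):
    """Sketch of the interface, grouped by the two categories above."""

    # Category 1: communication operations
    @abstractmethod
    def dispatch(self, tensor: torch.Tensor) -> torch.Tensor:
        """Scatter/slice a full-sequence tensor onto the current CP rank."""

    @abstractmethod
    def combine(self, tensor: torch.Tensor) -> torch.Tensor:
        """Gather per-rank outputs back into the full sequence."""

    # Category 2: backend-specific operations
    @abstractmethod
    def apply_rotary_pos_emb(
        self, t: torch.Tensor, freqs: torch.Tensor, *, config
    ) -> torch.Tensor:
        """Apply RoPE with the offsets required by this CP layout."""

    @abstractmethod
    def roll_tensor(self, tensor: torch.Tensor, shifts: int, dim: int) -> torch.Tensor:
        """Roll a tensor along the (sharded) sequence dimension."""
```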
2. Architectural Changes

The core design revolves around the `ContextParallelHandler` Abstract Base Class (ABC):

- ABC interface: `ContextParallelHandler` defines a common set of hooks (e.g., `dispatch`, `combine`, `roll_tensor`, `apply_rotary_pos_emb`). This allows Transformer layers to remain agnostic to underlying CP implementation details.
- Concrete handlers (`DefaultContextParallelHandler`, `MagiAttnContextParallelHandler`): we can seamlessly switch between different execution backends (Transformer Engine, Local, or MagiAttention) via a factory pattern (`get_cp_handler_cls`), without altering model code. A sketch of such a factory follows this list.
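The factory could look roughly like the sketch below. The stub classes and the selection rule (which argument values map to which handler) are assumptions; only the names `get_cp_handler_cls`, `DefaultContextParallelHandler`, and `MagiAttnContextParallelHandler` come from the PR description.

```python
class DefaultContextParallelHandler:
    """Stub standing in for the Transformer Engine / local handler."""


class MagiAttnContextParallelHandler:
    """Stub standing in for the MagiAttention handler."""


def get_cp_handler_cls(backend: str, cp_comm_type: str):
    """Map (backend, cp_comm_type) to a handler class; the rule is illustrative."""
    if cp_comm_type == "magi_attention":  # assumed flag value
        return MagiAttnContextParallelHandler
    if backend in ("transformer_engine", "local"):
        return DefaultContextParallelHandler
    raise ValueError(f"No CP handler registered for backend={backend!r}")
```

Call sites then instantiate the returned class with whatever CP metadata it needs and thread the instance through the model, as in the `get_batch` example above.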
3. Implementation Strategy & Migration Plan

The ultimate goal is to replace all instances of `PackedSeqParams` with `ContextParallelHandler`. By encapsulating all CP-dependent logic within backend-specific handlers, we achieve a truly backend-agnostic Context Parallel implementation.

Given the significant scope of this refactor, we are adopting a two-step migration strategy:
Step 1 (This PR)

- Introduce `ContextParallelHandler` for the primary components of the `GPTModel`.
- The `ContextParallelHandler` is implemented to ensure strict semantic equivalence with the original logic. This ensures backward compatibility with legacy code while establishing the new infrastructure (see the equivalence sketch below).
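One way to read "strict semantic equivalence" is that, for the default path, the handler's slicing must reproduce the existing load-balanced slicing exactly. The sketch below assumes the current 2-chunks-per-rank scheme and uses plain functions in place of the real classes; it is a test idea, not a test from this PR.

```python
import torch


def legacy_slice(tensor: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    """Assumed legacy behaviour: each rank keeps chunk r and its mirror
    out of 2 * cp_size chunks along the sequence dim (dim 1)."""
    chunks = tensor.chunk(2 * cp_size, dim=1)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=1)


def handler_slice(tensor: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    """Stand-in for the handler's dispatch, written differently on purpose
    (index_select) so the comparison below is not circular."""
    chunk = tensor.size(1) // (2 * cp_size)
    idx = list(range(cp_rank * chunk, (cp_rank + 1) * chunk))
    idx += list(range((2 * cp_size - 1 - cp_rank) * chunk, (2 * cp_size - cp_rank) * chunk))
    return tensor.index_select(1, torch.tensor(idx))


def test_handler_matches_legacy():
    tokens = torch.arange(2 * 16).reshape(2, 16)  # [batch, seq]
    for cp_size in (2, 4):
        for cp_rank in range(cp_size):
            torch.testing.assert_close(
                handler_slice(tokens, cp_rank, cp_size),
                legacy_slice(tokens, cp_rank, cp_size),
            )
```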
Step 2 (Future Work)

- Migrate the remaining components to `ContextParallelHandler`, removing direct dependencies on legacy structures.

4. Refactored APIs
The following APIs have been modified or refactored:
- `get_batch_on_this_cp_rank` (`megatron/core/utils.py`)
- `RotaryEmbedding.forward` (`megatron/core/models/common/embeddings/rotary_pos_embedding.py`)
- `get_thd_batch_on_this_cp_rank` (`megatron/core/utils.py`): now implemented in terms of `ContextParallelHandler`.
Contribution process

```mermaid
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
```

Pre-checks
- [ ] I want this PR in a versioned release and have added the appropriate milestone (e.g., Core 0.8)
Code review

The following process is enforced via the CODEOWNERS file for changes into `megatron/core`. For changes outside of `megatron/core`, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch
- (Step 1): Add PR label `Expert Review`
- (Step 2): Collect the expert reviewers' reviews
  - Add the `Expert Review` label when your PR is ready for review.
  - Final Review might get declined if these requirements are not fulfilled.
- (Step 3): Final Review
  - Add the `Final Review` label
- (Optional Step 4): Cherry-pick into release branch
  If this PR also needs to be merged into `core_r*` release branches, after this PR has been merged, select `Cherry-pick` to open a new PR into the release branch.

For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either [email protected] or [email protected].

Merging your PR
Any member of `core-adlr` and `core-nemo` will be able to merge your PR.