
Conversation

@littsk littsk commented Dec 24, 2025

What does this PR do ?

[Refactor] Introduce ContextParallelHandler for Unified Context Parallelism Abstraction

1. Design Philosophy: The ContextParallelHandler

This PR introduces the ContextParallelHandler, a unified abstraction layer designed to manage Context Parallelism (CP) operations across different attention backends and data formats. The primary goal is to decouple the parallelization strategy from the core model architecture, ensuring both extensibility and runtime efficiency.

We have refactored the context-parallel API to improve modularity and maintainability. The abstraction is encapsulated into two distinct categories:

  • Communication Primitives: Handles underlying distributed operations (e.g., dispatch, combine) and exposes high-level hooks for seamless interaction.
  • Non-intrusive Integration: Abstracts the logic required when Context Parallelism interacts with model internals (e.g., rope, roll_tensor), decoupling CP logic from the model definition and keeping the architecture clean.

2. Architectural Changes

The core design revolves around the ContextParallelHandler Abstract Base Class (ABC):

  • Unified Interface: Defines a standard contract for essential CP operations (e.g., dispatch, combine, roll_tensor, apply_rotary_pos_emb). This allows Transformer layers to remain agnostic to underlying CP implementation details.
  • Backend Agnosticism: By subclassing the handler (e.g., DefaultContextParallelHandler, MagiAttnContextParallelHandler), we can seamlessly switch between execution backends (Transformer Engine, Local, or MagiAttention) via a factory pattern (get_cp_handler_cls) without altering model code; a sketch of the interface follows below.
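
To make the contract concrete, here is a minimal sketch of the intended shape (illustrative only: the signatures, and any detail not named above, are assumptions rather than the final API):

from abc import ABC, abstractmethod


class ContextParallelHandler(ABC):
    """Unified contract for Context Parallel operations (sketch)."""

    @abstractmethod
    def dispatch(self, tensor):
        """Slice/scatter the full sequence into this CP rank's shard."""

    @abstractmethod
    def combine(self, tensor):
        """Merge per-rank shards back into the full sequence."""

    @abstractmethod
    def roll_tensor(self, tensor, shifts, dims):
        """CP-aware equivalent of rolling a sharded sequence."""

    @abstractmethod
    def apply_rotary_pos_emb(self, t, emb, config=None):
        """Apply RoPE with the offsets implied by this rank's shard."""


class DefaultContextParallelHandler(ContextParallelHandler):
    """Transformer Engine / local backend implementation (stub)."""


class MagiAttnContextParallelHandler(ContextParallelHandler):
    """MagiAttention backend implementation (stub)."""


def get_cp_handler_cls(backend, cp_comm_type):
    # Hypothetical factory body: select a handler subclass from the
    # configured backend so model code never branches on it directly.
    if backend == "magi-attention":
        return MagiAttnContextParallelHandler
    return DefaultContextParallelHandler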

3. Implementation Strategy & Migration Plan

The ultimate goal is to replace all instances of PackedSeqParams with ContextParallelHandler. By encapsulating all CP-dependent logic within backend-specific handlers, we achieve a truly backend-agnostic Context Parallel implementation.

Given the significant scope of this refactor, we are adopting a two-step migration strategy:

Step 1 (This PR)

  • Core Implementation: Implement the backend-agnostic ContextParallelHandler for the primary components of the GPTModel.
  • Legacy Compatibility: For peripheral components not yet fully refactored, the handler implementation preserves strict semantic equivalence with the original logic (see the sketch below), maintaining backward compatibility with legacy code while the new infrastructure is established.
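
To make "strict semantic equivalence" concrete: the default handler's dispatch is expected to reproduce the legacy load-balanced slicing, roughly as sketched below (this mirrors the pre-existing behavior rather than quoting this PR's code; legacy_style_dispatch is an illustrative name):

import torch


def legacy_style_dispatch(val, cp_rank, cp_size, seq_dim=1):
    # Legacy Megatron CP slicing: split the sequence into 2 * cp_size
    # chunks and give each rank one chunk from each end, which
    # load-balances causal-attention work across CP ranks.
    if cp_size == 1:
        return val
    val = val.view(
        *val.shape[:seq_dim],
        2 * cp_size,
        val.shape[seq_dim] // (2 * cp_size),
        *val.shape[seq_dim + 1:],
    )
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=val.device)
    val = val.index_select(seq_dim, index)
    return val.view(*val.shape[:seq_dim], -1, *val.shape[seq_dim + 2:])

A DefaultContextParallelHandler can route dispatch through this per-tensor slice so that existing callers of get_batch_on_this_cp_rank observe identical results.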

Step 2 (Future Work)

  • Full Migration: Gradually refactor the remaining components to fully utilize the encapsulation provided by ContextParallelHandler, removing direct dependencies on legacy structures.

4. Refactored APIs

The following APIs have been modified or refactored; illustrative sketches follow the list:

  1. Modified: get_batch_on_this_cp_rank (megatron/core/utils.py)
    • Now leverages the handler to manage batch slicing across ranks.
  2. Modified: RotaryEmbedding.forward (megatron/core/models/common/embeddings/rotary_pos_embedding.py)
    • Updated to accept the handler for CP-aware embedding application.
  3. Removed: get_thd_batch_on_this_cp_rank (megatron/core/utils.py)
    • This logic is now subsumed by the unified ContextParallelHandler.
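
Illustrative post-refactor shapes for the first two APIs (a sketch based on the description above; parameter names, defaults, and the pass-through behavior are assumptions):

# megatron/core/utils.py (sketch)
def get_batch_on_this_cp_rank(batch, cp_handler=None):
    # Delegate per-rank slicing to the handler; without a handler the
    # batch passes through unchanged. (The real code also has to treat
    # None entries and attention_mask layouts specially.)
    if cp_handler is None:
        return batch
    return {key: cp_handler.dispatch(val) for key, val in batch.items()}


# megatron/core/models/common/embeddings/rotary_pos_embedding.py (sketch)
class RotaryEmbedding:
    def forward(self, position_ids, mrope_section=None, cp_handler=None):
        emb = ...  # existing frequency computation, unchanged
        if cp_handler is not None:
            emb = cp_handler.get_emb_on_this_cp_rank(emb)
        return emb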

⚠️ For major changes (either in lines of code or in their impact), please make sure to first share and discuss a design doc with the team.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by a member of either core-adlr or core-nemo.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@littsk littsk requested review from a team as code owners December 24, 2025 05:44

copy-pr-bot bot commented Dec 24, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions github-actions bot requested a review from Phlip79 December 24, 2025 05:45
@littsk littsk changed the title [WIP][DO NOT MERGE][REFACTOR] Introduce ContextParallelHandler for Unified Context Parallelism Abstraction [DEV][WIP][DO NOT MERGE][REFACTOR] Introduce ContextParallelHandler for Unified Context Parallelism Abstraction Dec 24, 2025
@Phlip79 Phlip79 removed their request for review December 24, 2025 17:39
@littsk (Author) commented Dec 25, 2025

Basic Usage

The following examples demonstrate how the ContextParallelHandler is initialized and integrated into the model workflow to abstract away parallelization logic.

1. Initialization and Batch Dispatch

In the training loop, the specific handler class is selected based on the backend (e.g., transformer_engine, magi-attention) and instantiated. It is then used to slice the global batch for the current Context Parallel rank.

def get_batch(...):
    # 1. Select the appropriate handler class based on arguments
    backend = args.transformer_impl
    cp_comm_type = args.cp_comm_type
    cp_handler_cls = get_cp_handler_cls(backend=backend, cp_comm_type=cp_comm_type)
    
    # 2. Instantiate the handler (assuming necessary metadata is passed here)
    cp_handler = cp_handler_cls(...) 

    # 3. Use the handler to dispatch/slice the batch for the current CP rank
    batch = get_batch_on_this_cp_rank(batch, cp_handler=cp_handler)
    
    return batch

2. Model Forward Pass

The cp_handler is passed down through the model's forward method, allowing sub-components to access CP logic without direct coupling.

class GPTModel(nn.Module):
    def forward(self, ...):
        # Calculate sequence length requirements specific to the CP strategy
        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
            inference_context, self.decoder, decoder_input, self.config, cp_handler
        )

        # Generate Rotary Embeddings using the handler
        rotary_pos_emb = self.rotary_pos_emb(
            position_ids, self.mrope_section, cp_handler=cp_handler
        )
        
        # ... pass cp_handler to layers ...

3. Component Abstraction (RoPE & Attention)

Inside specific components, the handler manages communication (dispatch/combine) and backend-specific operations.

class RotaryEmbedding(nn.Module):
    def forward(self, ...):
        # ... existing logic ...
        
        # Slice the embedding for the current CP rank (if necessary)
        emb = cp_handler.get_emb_on_this_cp_rank(emb)
        
        return emb

    def get_rotary_seq_len(self, ...):
        # Delegate sequence length calculation to the handler
        rotary_seq_len = cp_handler.get_rotary_seq_len(...)
        return rotary_seq_len


class Attention(nn.Module):
    def forward(self, ...):
        # ... logic ...

        # Apply RoPE using the handler (handles offsets and CP specifics)
        query = cp_handler.apply_rotary_pos_emb(
            query, q_pos_emb, config=self.config
        )
        key = cp_handler.apply_rotary_pos_emb(
            key, k_pos_emb, config=self.config
        )

        # Execute Core Attention 
        # The handler manages the dispatch -> attention -> combine flow internally
        out = cp_handler.core_attn(
            attn_mod=self.core_attention,
            query=query,
            key=key,
            value=value,
            # ... other args
        )
        
        return out
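
For reference, the dispatch -> attention -> combine flow that core_attn manages internally might look roughly like this (continuing the hypothetical handler sketch from the PR description; dispatch_qkv is an invented placeholder for the pre-attention communication, and real backends such as Transformer Engine's ring attention interleave communication with computation rather than running three discrete phases):

class DefaultContextParallelHandler(ContextParallelHandler):
    def core_attn(self, attn_mod, query, key, value, **kwargs):
        # dispatch: pre-attention communication, e.g. exchanging or
        # gathering K/V shards depending on cp_comm_type
        query, key, value = self.dispatch_qkv(query, key, value)
        # attention: run the backend's core attention on local data
        out = attn_mod(query, key, value, **kwargs)
        # combine: post-attention communication restoring the layout
        # the rest of the layer expects
        return self.combine(out)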
