
feat(vision): add Vision DP for parallel ViT computation across SP ranks #5230

Open
aoshen524 wants to merge 1 commit into verl-project:main from aoshen524:feat/vision-dp-parallel

Conversation


@aoshen524 commented on Feb 7, 2026

Summary

  • Adds Vision Data Parallel (DP) to distribute whole images across Ulysses SP ranks for parallelized ViT computation
  • Hugely reduces ViT memory overhead: without this, every SP rank redundantly processes ALL images through the VisionTransformer. With Vision DP, each rank processes only 1/sp_size of the images, cutting ViT peak memory by roughly sp_size× (e.g. SP=4 → ~4× ViT memory reduction)
  • When ulysses_sp_size > 1, each rank processes a subset of images independently, then all-gathers embeddings once at the end
  • This avoids breaking cu_seqlens semantics (which would happen if patches within images were split across ranks)
  • Model-agnostic create_dp_vision_forward() wrapper supports any VisionTransformer with a forward(self, hidden_states, grid_thw) signature (see the sketch after this list)
  • Supports Qwen2-VL, Qwen2.5-VL, Qwen3-VL, and Qwen3-VL-MoE VisionTransformers
  • Includes GatherVisionEmbeddings custom autograd function with proper gradient scaling for FSDP/DDP compatibility
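
The sketch referenced above is a minimal single-process illustration of the Vision DP flow. The helper names carry a _sketch suffix because the real signatures in verl/utils/vision_dp.py may differ:

import torch


def get_image_patch_counts_sketch(grid_thw):
    # Each image contributes t * h * w patches.
    return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()


def assign_images_to_dp_ranks_sketch(num_images, dp_size):
    # Contiguous assignment: rank 0 takes the first chunk, rank 1 the next, etc.,
    # so concatenating rank outputs after the all-gather restores the original order.
    base, rem = divmod(num_images, dp_size)
    assignment, start = [], 0
    for rank in range(dp_size):
        count = base + (1 if rank < rem else 0)
        assignment.append(list(range(start, start + count)))
        start += count
    return assignment


def dp_vision_forward_sketch(vit_forward, hidden_states, grid_thw, dp_size, rank):
    # 1) Slice out this rank's images (whole images only, so cu_seqlens stays valid).
    patch_counts = get_image_patch_counts_sketch(grid_thw)
    offsets = [0]
    for count in patch_counts:
        offsets.append(offsets[-1] + count)
    mine = assign_images_to_dp_ranks_sketch(len(patch_counts), dp_size)[rank]
    if mine:
        local_states = torch.cat([hidden_states[offsets[i]:offsets[i + 1]] for i in mine])
        local_grid = grid_thw[mine]
        # 2) Run the unmodified ViT forward on the local subset only.
        local_embeds = vit_forward(local_states, local_grid)
    else:
        # Empty rank: emit a zero-row placeholder (the real code infers the output hidden size).
        local_embeds = hidden_states.new_zeros(0, hidden_states.shape[-1])
    # 3) The real wrapper then all-gathers per-rank embeddings across SP ranks
    #    (see the GatherVisionEmbeddings sketch below) to rebuild the full tensor.
    return local_embeds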

Why this matters

In VLM RL training with Ulysses SP, the ViT (VisionTransformer) is a major memory bottleneck. Text SP splits the sequence across ranks at each attention layer, but the ViT runs on the full set of images on every rank — meaning ViT memory usage is completely unaffected by SP. For scenarios with many images (e.g. multi-turn GUI agent training with screenshots), ViT activation memory can dominate.

Vision DP solves this by distributing images at the ViT level:

  • Before: Each of N SP ranks processes ALL images → ViT memory = O(total_images)
  • After: Each rank processes total_images/N images → ViT memory = O(total_images/N)

Key design choices

  • Image-level distribution (not patch-level): avoids breaking ViT's internal cu_seqlens tracking
  • Contiguous assignment: rank 0 gets images [0,1,...], rank 1 gets next chunk, etc. — no reordering needed after all-gather
  • Gradient scaling: the backward pass scales gradients by dp_size to compensate for partial image processing before FSDP/DDP reduction (sketched after this list)
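
A hedged sketch of what the gradient-scaling gather can look like; the actual GatherVisionEmbeddings in this PR may handle padding and ordering differently:

import torch
import torch.distributed as dist


class GatherVisionEmbeddingsSketch(torch.autograd.Function):
    """All-gather per-rank ViT embeddings; scale gradients by dp_size in backward."""

    @staticmethod
    def forward(ctx, local_embeds, patch_counts_per_rank, group=None):
        ctx.group = group
        ctx.patch_counts = list(patch_counts_per_rank)  # patches owned by each rank
        ctx.rank = dist.get_rank(group)
        ctx.dp_size = dist.get_world_size(group)
        hidden = local_embeds.shape[-1]
        max_n = max(ctx.patch_counts)
        # Pad to a common length so the all-gather uses equal-sized buffers on any backend.
        padded = local_embeds.new_zeros(max_n, hidden)
        padded[: local_embeds.shape[0]] = local_embeds
        gathered = [torch.empty_like(padded) for _ in range(ctx.dp_size)]
        dist.all_gather(gathered, padded, group=group)
        # Strip the padding; contiguous image assignment means plain concatenation
        # already restores the original image order.
        return torch.cat([g[:n] for g, n in zip(gathered, ctx.patch_counts)], dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank produced only 1/dp_size of the images, but FSDP/DDP will average
        # gradients across ranks as if every rank had seen all images, so scale up.
        start = sum(ctx.patch_counts[: ctx.rank])
        end = start + ctx.patch_counts[ctx.rank]
        grad_local = grad_output[start:end].contiguous() * ctx.dp_size
        return grad_local, None, None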

Test plan

  • 17 unit tests covering all utility functions (get_image_patch_counts, assign_images_to_dp_ranks, prepare_local_vision_inputs)
  • Integration tests for full workflow with varying image sizes
  • Edge cases: empty inputs, fewer images than ranks, single rank (see the illustration after this list)
  • Multi-GPU integration test with actual VLM model
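
A hypothetical illustration of the fewer-images-than-ranks edge case, reusing the contiguous-assignment logic sketched in the summary (this is not the actual code in tests/test_vision_dp.py):

def assign_contiguously(num_images, dp_size):
    # Mirrors the contiguous assignment sketched above.
    base, rem = divmod(num_images, dp_size)
    out, start = [], 0
    for rank in range(dp_size):
        count = base + (1 if rank < rem else 0)
        out.append(list(range(start, start + count)))
        start += count
    return out


def test_fewer_images_than_ranks():
    # 2 images on 4 SP ranks: two ranks get one image each, the other two get none
    # and must fall back to producing an empty embedding tensor of the right width.
    assert assign_contiguously(2, 4) == [[0], [1], [], []]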

🤖 Generated with Claude Code


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces Vision Data Parallelism (DP) to enable parallel ViT computation by distributing images across sequence parallel ranks. The implementation is well-structured, with new utilities in verl/utils/vision_dp.py and corresponding monkey-patching logic. The addition of comprehensive unit tests in tests/test_vision_dp.py is commendable.

My review focuses on improving code quality, robustness, and maintainability. I've identified a few areas with high-severity issues:

  • Redundant computation in prepare_local_vision_inputs.
  • Fragile logic for determining hidden_size in the DP wrapper, which could lead to runtime errors with different models.
  • Significant code duplication in monkey_patch.py that should be refactored.

Addressing these points will make the new functionality more robust and easier to maintain in the future.

Comment on lines +407 to +437
if ulysses_sp_size > 1:
    from verl.utils.vision_dp import create_dp_vision_forward

    # Patch Qwen2-VL VisionTransformer
    try:
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

        original_vision_forward = Qwen2VisionTransformerPretrainedModel.forward
        Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original_vision_forward)
        print(
            f"Monkey patch Qwen2VisionTransformerPretrainedModel.forward"
            f" for Vision DP (dp_size={ulysses_sp_size})"
        )
    except ImportError as e:
        print(f"Warning: Could not patch Qwen2VisionTransformer for Vision DP: {e}")

    # Patch Qwen2.5-VL VisionTransformer (uses a different class)
    try:
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
            Qwen2_5_VisionTransformerPretrainedModel,
        )

        original_vision_forward_25 = Qwen2_5_VisionTransformerPretrainedModel.forward
        Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original_vision_forward_25)
        print(
            f"Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward"
            f" for Vision DP (dp_size={ulysses_sp_size})"
        )
    except ImportError as e:
        print(f"Warning: Could not patch Qwen2_5VisionTransformer for Vision DP: {e}")


Severity: high

This block of code for monkey-patching the vision transformer is repeated multiple times in this file (here for Qwen2/2.5-VL, and again for Qwen3 models). This duplication makes the code harder to read and maintain. Please consider refactoring this logic into a helper function that takes the module path and class name as arguments. This would significantly reduce code duplication and improve maintainability.

For example, a helper could look like this:

def _patch_vision_model_for_dp(module_path, class_name, ulysses_sp_size):
    try:
        module = __import__(module_path, fromlist=[class_name])
        model_class = getattr(module, class_name)
        original_forward = model_class.forward
        model_class.forward = create_dp_vision_forward(original_forward)
        print(f"Monkey patch {class_name}.forward for Vision DP (dp_size={ulysses_sp_size})")
    except (ImportError, AttributeError) as e:
        print(f"Warning: Could not patch {class_name} for Vision DP: {e}")

Additionally, the step numbering is inconsistent ("Step 4" here, but "Step 3" for the Qwen3 models). A refactor would help resolve such inconsistencies.
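
For illustration, the duplicated blocks above would then reduce to one call per model class (module paths and class names taken from this diff):

_patch_vision_model_for_dp(
    "transformers.models.qwen2_vl.modeling_qwen2_vl",
    "Qwen2VisionTransformerPretrainedModel",
    ulysses_sp_size,
)
_patch_vision_model_for_dp(
    "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl",
    "Qwen2_5_VisionTransformerPretrainedModel",
    ulysses_sp_size,
)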


# Compute patch offsets for each image
patch_counts = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()

Severity: high

The patch_counts are recalculated here, but they have already been computed in the calling function create_dp_vision_forward. This leads to redundant computation. You can improve efficiency and code clarity by passing patch_counts as an argument to prepare_local_vision_inputs and removing this line.
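
One possible shape of that change, with the signature assumed rather than taken from the diff:

import torch


def prepare_local_vision_inputs(hidden_states, grid_thw, rank_image_indices, patch_counts):
    # patch_counts is computed once in create_dp_vision_forward and passed in,
    # so the per-image offsets no longer need to be re-derived from grid_thw here.
    offsets = [0]
    for count in patch_counts:
        offsets.append(offsets[-1] + count)
    local_slices = [hidden_states[offsets[i]:offsets[i + 1]] for i in rank_image_indices]
    local_states = torch.cat(local_slices, dim=0) if local_slices else hidden_states[:0]
    local_grid = grid_thw[rank_image_indices] if rank_image_indices else grid_thw[:0]
    return local_states, local_grid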

Comment on lines +387 to +406
# This rank has no images, create empty tensor with correct hidden size
# Try multiple common attribute paths for hidden size detection
if hasattr(self, "merger") and hasattr(self.merger, "ln_q"):
    ln_q = self.merger.ln_q
    if hasattr(ln_q, "normalized_shape"):
        hidden_size = ln_q.normalized_shape[0]
    elif hasattr(ln_q, "weight"):
        hidden_size = ln_q.weight.shape[0]
    else:
        raise RuntimeError(f"Cannot determine hidden_size from ln_q. Type: {type(ln_q).__name__}")
elif hasattr(self, "out_hidden_size"):
    hidden_size = self.out_hidden_size
elif hasattr(self, "config") and hasattr(self.config, "hidden_size"):
    hidden_size = self.config.hidden_size
else:
    raise RuntimeError(
        f"Cannot determine hidden_size for VisionTransformer. "
        f"Model type: {type(self).__name__}. "
        f"Available attributes: {[a for a in dir(self) if not a.startswith('_')]}"
    )

Severity: high

The logic to determine hidden_size for ranks with no images is complex and relies on a series of hasattr checks for specific model attributes. This approach is brittle and may break with new or refactored models, leading to RuntimeError. A more robust approach would be to make this explicit. For example, create_dp_vision_forward could accept the hidden_size or a hidden_size_getter function as an argument. This would be provided at the patching site where the model context is clear, making the wrapper more reliable and easier to maintain.
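
A hedged sketch of that suggestion; the names below are illustrative, not the verl API:

def make_empty_vision_embeds(model, hidden_states, hidden_size_getter=None):
    """Build the zero-image placeholder without probing merger/ln_q attributes."""
    if hidden_size_getter is not None:
        hidden_size = hidden_size_getter(model)   # supplied at the patching site
    else:
        hidden_size = model.config.hidden_size    # conservative default
    return hidden_states.new_zeros(0, hidden_size)


# At the patching site, where the model family is known, e.g. for Qwen2.5-VL:
#   Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(
#       Qwen2_5_VisionTransformerPretrainedModel.forward,
#       hidden_size_getter=lambda m: m.config.hidden_size,
#   )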

When Ulysses sequence parallelism is enabled (sp_size > 1), the
VisionTransformer processes all images on every rank redundantly.
This adds Vision Data Parallel (DP) which distributes whole images
across SP ranks for independent ViT processing, then all-gathers
embeddings once at the end — reducing ViT peak memory by ~sp_size x.

Also removes the forced eager attention fallback for ViT when SP>1,
since Vision DP makes each rank process only its local images and
the _ulysses_flash_attention_forward already correctly skips ViT
(via position_ids is None guard).

Supports Qwen2-VL, Qwen2.5-VL, Qwen3-VL, and Qwen3-VL-MoE.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@aoshen524 force-pushed the feat/vision-dp-parallel branch from 76ba823 to 8763b64 on February 7, 2026 at 13:56
