feat(vision): add Vision DP for parallel ViT computation across SP ranks #5230
aoshen524 wants to merge 1 commit into verl-project:main
Conversation
Code Review
This pull request introduces Vision Data Parallelism (DP) to enable parallel ViT computation by distributing images across sequence parallel ranks. The implementation is well-structured, with new utilities in verl/utils/vision_dp.py and corresponding monkey-patching logic. The addition of comprehensive unit tests in tests/test_vision_dp.py is commendable.
My review focuses on improving code quality, robustness, and maintainability. I've identified a few areas with high-severity issues:
- Redundant computation in prepare_local_vision_inputs.
- Fragile logic for determining hidden_size in the DP wrapper, which could lead to runtime errors with different models.
- Significant code duplication in monkey_patch.py that should be refactored.
Addressing these points will make the new functionality more robust and easier to maintain in the future.
if ulysses_sp_size > 1:
    from verl.utils.vision_dp import create_dp_vision_forward

    # Patch Qwen2-VL VisionTransformer
    try:
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

        original_vision_forward = Qwen2VisionTransformerPretrainedModel.forward
        Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original_vision_forward)
        print(
            f"Monkey patch Qwen2VisionTransformerPretrainedModel.forward"
            f" for Vision DP (dp_size={ulysses_sp_size})"
        )
    except ImportError as e:
        print(f"Warning: Could not patch Qwen2VisionTransformer for Vision DP: {e}")

    # Patch Qwen2.5-VL VisionTransformer (uses a different class)
    try:
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
            Qwen2_5_VisionTransformerPretrainedModel,
        )

        original_vision_forward_25 = Qwen2_5_VisionTransformerPretrainedModel.forward
        Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original_vision_forward_25)
        print(
            f"Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward"
            f" for Vision DP (dp_size={ulysses_sp_size})"
        )
    except ImportError as e:
        print(f"Warning: Could not patch Qwen2_5VisionTransformer for Vision DP: {e}")
This block of code for monkey-patching the vision transformer is repeated multiple times in this file (here for Qwen2/2.5-VL, and again for Qwen3 models). This duplication makes the code harder to read and maintain. Please consider refactoring this logic into a helper function that takes the module path and class name as arguments. This would significantly reduce code duplication and improve maintainability.
For example, a helper could look like this:
def _patch_vision_model_for_dp(module_path, class_name, ulysses_sp_size):
    try:
        module = __import__(module_path, fromlist=[class_name])
        model_class = getattr(module, class_name)
        original_forward = model_class.forward
        model_class.forward = create_dp_vision_forward(original_forward)
        print(f"Monkey patch {class_name}.forward for Vision DP (dp_size={ulysses_sp_size})")
    except (ImportError, AttributeError) as e:
        print(f"Warning: Could not patch {class_name} for Vision DP: {e}")

Additionally, the step numbering is inconsistent ("Step 4" here, but "Step 3" for the Qwen3 models). A refactor would help resolve such inconsistencies.
)

# Compute patch offsets for each image
patch_counts = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()
The patch_counts are recalculated here, but they have already been computed in the calling function create_dp_vision_forward. This leads to redundant computation. You can improve efficiency and code clarity by passing patch_counts as an argument to prepare_local_vision_inputs and removing this line.
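A minimal sketch of the suggested change, assuming a signature roughly like the one below (the actual function in verl/utils/vision_dp.py may take different arguments):

import torch

# Sketch only: patch_counts is computed once in create_dp_vision_forward and
# passed down, instead of being rederived from grid_thw here.
def prepare_local_vision_inputs(hidden_states, grid_thw, image_indices, patch_counts):
    # Cumulative offsets delimit each image's patches in the flat hidden_states.
    offsets = [0]
    for count in patch_counts:
        offsets.append(offsets[-1] + count)
    chunks = [hidden_states[offsets[i]:offsets[i + 1]] for i in image_indices]
    local_hidden_states = torch.cat(chunks, dim=0) if chunks else hidden_states[:0]
    local_grid_thw = grid_thw[image_indices]
    return local_hidden_states, local_grid_thw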
# This rank has no images, create empty tensor with correct hidden size
# Try multiple common attribute paths for hidden size detection
if hasattr(self, "merger") and hasattr(self.merger, "ln_q"):
    ln_q = self.merger.ln_q
    if hasattr(ln_q, "normalized_shape"):
        hidden_size = ln_q.normalized_shape[0]
    elif hasattr(ln_q, "weight"):
        hidden_size = ln_q.weight.shape[0]
    else:
        raise RuntimeError(f"Cannot determine hidden_size from ln_q. Type: {type(ln_q).__name__}")
elif hasattr(self, "out_hidden_size"):
    hidden_size = self.out_hidden_size
elif hasattr(self, "config") and hasattr(self.config, "hidden_size"):
    hidden_size = self.config.hidden_size
else:
    raise RuntimeError(
        f"Cannot determine hidden_size for VisionTransformer. "
        f"Model type: {type(self).__name__}. "
        f"Available attributes: {[a for a in dir(self) if not a.startswith('_')]}"
    )
The logic to determine hidden_size for ranks with no images is complex and relies on a series of hasattr checks for specific model attributes. This approach is brittle and may break with new or refactored models, leading to RuntimeError. A more robust approach would be to make this explicit. For example, create_dp_vision_forward could accept the hidden_size or a hidden_size_getter function as an argument. This would be provided at the patching site where the model context is clear, making the wrapper more reliable and easier to maintain.
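A hedged sketch of that suggestion; the real create_dp_vision_forward signature and body differ, and the getter shown is only one possible choice per model family:

from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

# Sketch only: the wrapper takes an explicit hidden_size_getter instead of
# probing model attributes. The per-rank image assignment and the final
# all-gather from the PR are elided here.
def create_dp_vision_forward(original_forward, hidden_size_getter):
    def dp_vision_forward(self, hidden_states, grid_thw):
        local_hidden_states, local_grid_thw = hidden_states, grid_thw  # elided: per-rank assignment
        if local_grid_thw.shape[0] == 0:
            # Empty rank: hidden size comes from the injected getter.
            local_embeds = hidden_states.new_zeros((0, hidden_size_getter(self)))
        else:
            local_embeds = original_forward(self, local_hidden_states, local_grid_thw)
        return local_embeds  # elided: all-gather across SP ranks
    return dp_vision_forward

# Provided at the patching site, where the model family is known:
original_vision_forward = Qwen2VisionTransformerPretrainedModel.forward
Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(
    original_vision_forward,
    hidden_size_getter=lambda m: m.config.hidden_size,
)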
When Ulysses sequence parallelism is enabled (sp_size > 1), the VisionTransformer processes all images on every rank redundantly. This adds Vision Data Parallel (DP), which distributes whole images across SP ranks for independent ViT processing, then all-gathers embeddings once at the end, reducing ViT peak memory by ~sp_size x.

Also removes the forced eager attention fallback for ViT when SP > 1, since Vision DP makes each rank process only its local images and the _ulysses_flash_attention_forward already correctly skips ViT (via the position_ids is None guard).

Supports Qwen2-VL, Qwen2.5-VL, Qwen3-VL, and Qwen3-VL-MoE.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- Each SP rank processes only 1/sp_size of the images, reducing ViT peak memory by ~sp_size x (e.g. SP=4 → ~4x ViT memory reduction)
- When ulysses_sp_size > 1, each rank processes a subset of images independently, then all-gathers embeddings once at the end
- Whole images are assigned to ranks, avoiding broken cu_seqlens semantics (which would happen if patches within images were split across ranks); see the balancing sketch after this list
- A generic create_dp_vision_forward() wrapper supports any VisionTransformer with a forward(self, hidden_states, grid_thw) signature
- A GatherVisionEmbeddings custom autograd function with proper gradient scaling for FSDP/DDP compatibility
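The following is an illustrative sketch (not the code in this PR) of how whole images can be balanced across SP ranks by patch count; the actual assign_images_to_dp_ranks in verl/utils/vision_dp.py may use a different strategy:

def assign_images_to_dp_ranks_sketch(patch_counts, dp_size):
    """Greedy balancing sketch: each whole image goes to the currently lightest rank."""
    rank_loads = [0] * dp_size
    assignment = [[] for _ in range(dp_size)]
    # Place larger images first so per-rank patch totals stay close.
    for image_idx in sorted(range(len(patch_counts)), key=lambda i: -patch_counts[i]):
        rank = min(range(dp_size), key=lambda r: rank_loads[r])
        assignment[rank].append(image_idx)
        rank_loads[rank] += patch_counts[image_idx]
    return assignment

# Example: 5 images on SP=2; each rank receives whole images, never partial ones.
print(assign_images_to_dp_ranks_sketch([400, 100, 300, 200, 100], dp_size=2))
# -> [[0, 1, 4], [2, 3]] with per-rank patch totals 600 and 500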
Why this matters

In VLM RL training with Ulysses SP, the ViT (VisionTransformer) is a major memory bottleneck. Text SP splits the sequence across ranks at each attention layer, but the ViT runs on the full set of images on every rank, meaning ViT memory usage is completely unaffected by SP. For scenarios with many images (e.g. multi-turn GUI agent training with screenshots), ViT activation memory can dominate.
Vision DP solves this by distributing images at the ViT level:
- Without Vision DP: every rank runs the ViT on all images → ViT memory = O(total_images)
- With Vision DP: each of the N SP ranks runs the ViT on total_images/N images → ViT memory = O(total_images/N)
Key design choices

- Whole images (not individual patches) are assigned to ranks, preserving per-image cu_seqlens tracking
- Gradients are scaled by dp_size to compensate for partial image processing before FSDP/DDP reduction (see the sketch after this list)
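A hedged sketch of what such a gather-with-gradient-scaling autograd function can look like; it assumes every rank contributes the same number of patches, which the real GatherVisionEmbeddings in verl/utils/vision_dp.py does not need to assume:

import torch
import torch.distributed as dist

class GatherVisionEmbeddingsSketch(torch.autograd.Function):
    """Sketch: all-gather per-rank ViT embeddings, scale gradients by dp_size."""

    @staticmethod
    def forward(ctx, local_embeds, group):
        ctx.group = group
        ctx.dp_size = dist.get_world_size(group)
        ctx.local_len = local_embeds.shape[0]
        gathered = [torch.empty_like(local_embeds) for _ in range(ctx.dp_size)]
        dist.all_gather(gathered, local_embeds.contiguous(), group=group)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        rank = dist.get_rank(ctx.group)
        start = rank * ctx.local_len
        grad_local = grad_output[start:start + ctx.local_len]
        # Scale by dp_size so ViT parameters, which only saw this rank's local
        # images, are not under-weighted after FSDP/DDP gradient averaging.
        return grad_local * ctx.dp_size, None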
Test plan

- Unit tests in tests/test_vision_dp.py covering the new utilities (get_image_patch_counts, assign_images_to_dp_ranks, prepare_local_vision_inputs)

🤖 Generated with Claude Code