diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py
index fef38093bf..107e225cd0 100644
--- a/modelopt/torch/quantization/plugins/vllm.py
+++ b/modelopt/torch/quantization/plugins/vllm.py
@@ -47,7 +47,12 @@ except ImportError:
         continue
 
 
-if importlib.util.find_spec("vllm.attention.layers"):  # vllm < 0.15.0
+try:
+    _has_attention_layers = importlib.util.find_spec("vllm.attention.layers") is not None
+except (ModuleNotFoundError, ValueError):
+    _has_attention_layers = False
+
+if _has_attention_layers:  # vllm < 0.15.0
     from vllm.attention.layers.cross_attention import CrossAttention
     from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 else:
@@ -60,12 +65,17 @@ except ImportError:
     EncoderOnlyAttention = None
 
 
-if importlib.util.find_spec("vllm.attention.layer"):
+try:
+    _has_attention_layer = importlib.util.find_spec("vllm.attention.layer") is not None
+except (ModuleNotFoundError, ValueError):
+    _has_attention_layer = False
+
+if _has_attention_layer:
     import vllm.attention.layer as vllm_attention
 
     try:
         VllmMLAAttention = vllm_attention.MLAAttention
-    except ImportError:
+    except (AttributeError, ImportError):
         VllmMLAAttention = None
 
 _ATTENTION_TYPES = tuple(
@@ -131,12 +141,16 @@ def _get_device_dtype(module: torch.nn.Module) -> tuple:
         return dev, dt_resolved
 
     # KV-cache tensors are available after allocation; respect kv_cache_dtype when set.
+    # kv_cache is a list of tensors (v0) or a single tensor (v1).
     kv = getattr(module, "kv_cache", None)
-    if kv and kv[0] is not None:
-        t0 = kv[0]
-        spec = getattr(module, "kv_cache_dtype", t0.dtype)
-        out_dtype = t0.dtype if spec == "auto" else (_vllm_attr_dtype_to_torch(spec) or t0.dtype)
-        return t0.device, out_dtype
+    if kv is not None:
+        t0 = kv[0] if isinstance(kv, (list, tuple)) and len(kv) > 0 else kv
+        if isinstance(t0, torch.Tensor) and t0.numel() > 0:
+            spec = getattr(module, "kv_cache_dtype", t0.dtype)
+            out_dtype = (
+                t0.dtype if spec == "auto" else (_vllm_attr_dtype_to_torch(spec) or t0.dtype)
+            )
+            return t0.device, out_dtype
 
     # Shallow scan: weights often live on child modules rather than the attention module itself.
     for mod in (module, *module.children()):
@@ -223,7 +237,11 @@ def create_parallel_state():
     """Create a parallel state for vLLM."""
     dp_group = get_dp_group().device_group
     tp_group = get_tp_group().device_group
-    ep_group = get_ep_group().device_group
+    try:
+        # EP group is only created for MoE models; dense models don't have one.
+        ep_group = get_ep_group().device_group
+    except (AssertionError, RuntimeError):
+        ep_group = -1
     return ParallelState(dp_group, tp_group, ep_group)
 
 
@@ -426,14 +444,18 @@ def modelopt_post_restore(self, prefix: str = "") -> None:
         _vllm_attention_modelopt_post_restore(self)
 
 
-@QuantModuleRegistry.register({CrossAttention: "vllm_CrossAttention"})
-class _QuantVLLMCrossAttention(_QuantVLLMAttention):
-    pass
+if CrossAttention is not None:
+    @QuantModuleRegistry.register({CrossAttention: "vllm_CrossAttention"})
+    class _QuantVLLMCrossAttention(_QuantVLLMAttention):
+        pass
 
 
-@QuantModuleRegistry.register({EncoderOnlyAttention: "vllm_EncoderOnlyAttention"})
-class _QuantVLLMEncoderOnlyAttention(_QuantVLLMAttention):
-    pass
+
+if EncoderOnlyAttention is not None:
+
+    @QuantModuleRegistry.register({EncoderOnlyAttention: "vllm_EncoderOnlyAttention"})
+    class _QuantVLLMEncoderOnlyAttention(_QuantVLLMAttention):
+        pass
 
 
 if VllmMLAAttention is not None: