Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions modelopt/torch/quantization/plugins/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@
except ImportError:
continue

if importlib.util.find_spec("vllm.attention.layers"): # vllm < 0.15.0
try:
_has_attention_layers = importlib.util.find_spec("vllm.attention.layers") is not None
except (ModuleNotFoundError, ValueError):
_has_attention_layers = False

if _has_attention_layers: # vllm < 0.15.0
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
else:
Expand All @@ -60,12 +65,17 @@
except ImportError:
EncoderOnlyAttention = None

if importlib.util.find_spec("vllm.attention.layer"):
try:
_has_attention_layer = importlib.util.find_spec("vllm.attention.layer") is not None
except (ModuleNotFoundError, ValueError):
_has_attention_layer = False

if _has_attention_layer:
import vllm.attention.layer as vllm_attention

try:
VllmMLAAttention = vllm_attention.MLAAttention
except ImportError:
except (AttributeError, ImportError):
VllmMLAAttention = None

_ATTENTION_TYPES = tuple(
Expand Down Expand Up @@ -131,12 +141,16 @@ def _get_device_dtype(module: torch.nn.Module) -> tuple:
return dev, dt_resolved

# KV-cache tensors are available after allocation; respect kv_cache_dtype when set.
# kv_cache is a list of tensors (v0) or a single tensor (v1).
kv = getattr(module, "kv_cache", None)
if kv and kv[0] is not None:
t0 = kv[0]
spec = getattr(module, "kv_cache_dtype", t0.dtype)
out_dtype = t0.dtype if spec == "auto" else (_vllm_attr_dtype_to_torch(spec) or t0.dtype)
return t0.device, out_dtype
if kv is not None:
t0 = kv[0] if isinstance(kv, (list, tuple)) and len(kv) > 0 else kv
if isinstance(t0, torch.Tensor) and t0.numel() > 0:
spec = getattr(module, "kv_cache_dtype", t0.dtype)
out_dtype = (
t0.dtype if spec == "auto" else (_vllm_attr_dtype_to_torch(spec) or t0.dtype)
)
return t0.device, out_dtype

# Shallow scan: weights often live on child modules rather than the attention module itself.
for mod in (module, *module.children()):
Expand Down Expand Up @@ -223,7 +237,11 @@ def create_parallel_state():
"""Create a parallel state for vLLM."""
dp_group = get_dp_group().device_group
tp_group = get_tp_group().device_group
ep_group = get_ep_group().device_group
try:
# EP group is only created for MoE models; dense models don't have one.
ep_group = get_ep_group().device_group
except (AssertionError, RuntimeError):
ep_group = -1
return ParallelState(dp_group, tp_group, ep_group)


Expand Down Expand Up @@ -426,14 +444,18 @@ def modelopt_post_restore(self, prefix: str = "") -> None:
_vllm_attention_modelopt_post_restore(self)


@QuantModuleRegistry.register({CrossAttention: "vllm_CrossAttention"})
class _QuantVLLMCrossAttention(_QuantVLLMAttention):
pass
if CrossAttention is not None:

@QuantModuleRegistry.register({CrossAttention: "vllm_CrossAttention"})
class _QuantVLLMCrossAttention(_QuantVLLMAttention):
pass

@QuantModuleRegistry.register({EncoderOnlyAttention: "vllm_EncoderOnlyAttention"})
class _QuantVLLMEncoderOnlyAttention(_QuantVLLMAttention):
pass

if EncoderOnlyAttention is not None:

@QuantModuleRegistry.register({EncoderOnlyAttention: "vllm_EncoderOnlyAttention"})
class _QuantVLLMEncoderOnlyAttention(_QuantVLLMAttention):
pass


if VllmMLAAttention is not None:
Expand Down
Loading