Skip to content

Commit 15446dd

Browse files
youzhedianhongchao
andauthored
feat: support mtp in vllm, update vllm's drafter model when update_weights (#81)
Co-authored-by: hongchao <[email protected]>
1 parent f6910d6 commit 15446dd

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

checkpoint_engine/worker.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
1111

1212

13+
_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
14+
15+
1316
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
1417
func, args = handle
1518
list_args = list(args)
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
2932
offset: int
3033

3134

32-
def _extract_weights(
33-
payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
34-
) -> list[tuple[str, torch.Tensor]]:
35+
def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
3536
assert buffer is not None
36-
weights: list[tuple[str, torch.Tensor]] = []
37+
weights: _WEIGHTS_TYPE = []
3738
for item in payload:
3839
shape = item["shape"]
3940
if isinstance(shape, list | tuple):
@@ -166,12 +167,31 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
166167
self.device = torch.device(f"npu:{self.local_rank}")
167168
assert self.device is not None
168169

170+
def _load_weights(weights: _WEIGHTS_TYPE):
171+
# Load main model weights
172+
self.model_runner.model.load_weights(weights)
173+
# Load drafter model weights if MTP/speculative decoding is enabled
174+
if (
175+
getattr(self.model_runner, "drafter", None) is not None
176+
and getattr(self.model_runner.drafter, "model", None) is not None
177+
):
178+
self.model_runner.drafter.model.load_weights(weights=weights)
179+
180+
def _post_hook():
181+
process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
182+
# Also trigger drafter model's post processing if MTP is enabled
183+
if (
184+
getattr(self.model_runner, "drafter", None) is not None
185+
and getattr(self.model_runner.drafter, "model", None) is not None
186+
):
187+
process_weights_after_loading(
188+
self.model_runner.drafter.model, self.model_config, self.device
189+
)
190+
169191
update_weights_from_ipc(
170192
self._zmq_ctx,
171193
zmq_handles[self._device_uuid],
172194
device_id=self.device.index,
173-
run=self.model_runner.model.load_weights,
174-
post_hook=lambda: process_weights_after_loading(
175-
self.model_runner.model, self.model_config, self.device
176-
),
195+
run=_load_weights,
196+
post_hook=_post_hook,
177197
)

0 commit comments

Comments
 (0)