1010from checkpoint_engine .device_utils import DeviceManager , npu_generate_uuid
1111
1212
13+ _WEIGHTS_TYPE = list [tuple [str , torch .Tensor ]]
14+
15+
1316def _rebuild_ipc (handle : tuple [Callable , tuple ], device_id : int | None = None ) -> torch .Tensor :
1417 func , args = handle
1518 list_args = list (args )
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
2932 offset : int
3033
3134
32- def _extract_weights (
33- payload : list [FlattenedTensorMetadata ], buffer : torch .Tensor
34- ) -> list [tuple [str , torch .Tensor ]]:
35+ def _extract_weights (payload : list [FlattenedTensorMetadata ], buffer : torch .Tensor ) -> _WEIGHTS_TYPE :
3536 assert buffer is not None
36- weights : list [ tuple [ str , torch . Tensor ]] = []
37+ weights : _WEIGHTS_TYPE = []
3738 for item in payload :
3839 shape = item ["shape" ]
3940 if isinstance (shape , list | tuple ):
@@ -166,12 +167,31 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
166167 self .device = torch .device (f"npu:{ self .local_rank } " )
167168 assert self .device is not None
168169
170+ def _load_weights (weights : _WEIGHTS_TYPE ):
171+ # Load main model weights
172+ self .model_runner .model .load_weights (weights )
173+ # Load drafter model weights if MTP/speculative decoding is enabled
174+ if (
175+ getattr (self .model_runner , "drafter" , None ) is not None
176+ and getattr (self .model_runner .drafter , "model" , None ) is not None
177+ ):
178+ self .model_runner .drafter .model .load_weights (weights = weights )
179+
180+ def _post_hook ():
181+ process_weights_after_loading (self .model_runner .model , self .model_config , self .device )
182+ # Also trigger drafter model's post processing if MTP is enabled
183+ if (
184+ getattr (self .model_runner , "drafter" , None ) is not None
185+ and getattr (self .model_runner .drafter , "model" , None ) is not None
186+ ):
187+ process_weights_after_loading (
188+ self .model_runner .drafter .model , self .model_config , self .device
189+ )
190+
169191 update_weights_from_ipc (
170192 self ._zmq_ctx ,
171193 zmq_handles [self ._device_uuid ],
172194 device_id = self .device .index ,
173- run = self .model_runner .model .load_weights ,
174- post_hook = lambda : process_weights_after_loading (
175- self .model_runner .model , self .model_config , self .device
176- ),
195+ run = _load_weights ,
196+ post_hook = _post_hook ,
177197 )
0 commit comments