|
1 | 1 | import gc |
2 | 2 | import traceback |
3 | 3 | from collections.abc import Callable |
| 4 | +from functools import cached_property |
4 | 5 | from typing import TypedDict |
5 | 6 |
|
6 | 7 | import torch |
@@ -117,6 +118,21 @@ class VllmColocateWorkerExtension: |
117 | 118 | `worker_extension_cls` argument when initializing the vLLM worker. |
118 | 119 | """ |
119 | 120 |
|
@cached_property
def _device_uuid(self) -> str:
    """Return an identifier string for the device this worker runs on.

    The value is computed on first access and then stored on the instance
    by ``cached_property``.

    Returns:
        On ``"cuda"``: whatever ``current_platform.get_device_uuid`` returns
        for ``self.device.index``.
        On ``"npu"``: the string ``"NPU-"`` followed by the result of
        ``npu_generate_uuid()`` (helper defined outside this snippet).

    Raises:
        ValueError: If the platform's device type is neither ``"cuda"``
            nor ``"npu"``.
    """
    # Imported inside the property rather than at module level, as in the
    # original; the import therefore only runs when the property is evaluated.
    from vllm.platforms import current_platform

    platform_kind = current_platform.device_type

    if platform_kind == "npu":
        return f"NPU-{npu_generate_uuid()}"

    if platform_kind == "cuda":
        # Reads ``self.device``; the caller is expected to have set it.
        return current_platform.get_device_uuid(self.device.index)

    raise ValueError(f"Unsupported device type: {current_platform.device_type}")
| 131 | + |
@cached_property
def _zmq_ctx(self) -> zmq.Context:
    """Return this instance's ZeroMQ context.

    A new ``zmq.Context`` is constructed on first access; ``cached_property``
    then stores it on the instance so later accesses reuse the same object.
    """
    context = zmq.Context()
    return context
| 135 | + |
120 | 136 | def update_weights_from_ipc(self, zmq_handles: dict[str, str]): |
121 | 137 | """ |
122 | 138 | Update model weights from checkpoint-engine via IPC communication. |
@@ -149,16 +165,6 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]): |
149 | 165 | if current_platform.device_type == "npu" and self.device is None: |
150 | 166 | self.device = torch.device(f"npu:{self.local_rank}") |
151 | 167 | assert self.device is not None |
152 | | - if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: |
153 | | - self._zmq_ctx = zmq.Context() |
154 | | - |
155 | | - if not hasattr(self, "_device_uuid") or self._device_uuid is None: |
156 | | - if current_platform.device_type == "cuda": |
157 | | - self._device_uuid = current_platform.get_device_uuid(self.device.index) |
158 | | - elif current_platform.device_type == "npu": |
159 | | - self._device_uuid = f"NPU-{npu_generate_uuid()}" |
160 | | - else: |
161 | | - raise ValueError(f"Unsupported device type: {current_platform.device_type}") |
162 | 168 |
|
163 | 169 | update_weights_from_ipc( |
164 | 170 | self._zmq_ctx, |
|
0 commit comments