Skip to content

Commit f0591ab

Browse files
authored
Apply suggestions from code review
Signed-off-by: Hubert Zhang <[email protected]>
1 parent 1fa937b commit f0591ab

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

checkpoint_engine/worker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,17 +172,17 @@ def _load_weights(weights: _WEIGHTS_TYPE):
172172
self.model_runner.model.load_weights(weights)
173173
# Load drafter model weights if MTP/speculative decoding is enabled
174174
if (
175-
getattr(self.model_runner, "drafter", None) is not None
176-
and getattr(self.model_runner.drafter, "model", None) is not None
175+
hasattr(self.model_runner, "drafter")
176+
and hasattr(self.model_runner.drafter, "model")
177177
):
178178
self.model_runner.drafter.model.load_weights(weights=weights)
179179

180180
def _post_hook():
181181
process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
182182
# Also trigger drafter model's post processing if MTP is enabled
183183
if (
184-
getattr(self.model_runner, "drafter", None) is not None
185-
and getattr(self.model_runner.drafter, "model", None) is not None
184+
hasattr(self.model_runner, "drafter")
185+
and hasattr(self.model_runner.drafter, "model")
186186
):
187187
process_weights_after_loading(
188188
self.model_runner.drafter.model, self.model_config, self.device

0 commit comments

Comments
 (0)