Skip to content

Commit 33ea56d

Browse files
committed
lint
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent 1bcdeef commit 33ea56d

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,7 +2062,9 @@ def _iter_params_with_optional_kv_scales(
         This helper is used by both IPC-based streaming and collective broadcast
         so that the logic for adding KV scales stays consistent in one place.
         """
-        from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import get_vllm_qkv_scale_names
+        from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import (
+            get_vllm_qkv_scale_names,
+        )

         base_iter = self.megatron_bridge.export_hf_weights(
             [self.model],
@@ -2469,7 +2471,9 @@ def calibrate_qkv_fp8_scales(
            { "format": "fp8", "percentile": float, "margin": float,
              "layers": { layer_name: {"k_scale": float, "v_scale": float[, "q_scale": float] } } }
        """
-        from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import convert_calibration_to_vllm_format
+        from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import (
+            convert_calibration_to_vllm_format,
+        )

        # Allow overriding FP8 max for Q, K, V via environment variables for ease of testing.
        # Defaults align with FP8 e4m3 max magnitude.

pyproject.toml

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@ dependencies = [
 ]

 [project.optional-dependencies]
-fsdp = [
-    "flash-attn==2.8.1",
-    "mamba-ssm",
-    "causal-conv1d",
-]
+fsdp = ["flash-attn==2.8.1", "mamba-ssm", "causal-conv1d"]
 automodel = [
     "nemo-automodel",
     # Flash-attn version should be selected to satisfy both TE + vLLM requirements (xformers in particular)
@@ -77,9 +73,7 @@ vllm = [
     "num2words>=0.5.14",
 ]
 # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
-vllm_for_train = [
-    "vllm==0.11.0",
-]
+vllm_for_train = ["vllm==0.11.0"]
 mcore = [
     # also need cudnn (https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network)
     # wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb

pyrefly.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ project-includes = [
     "nemo_rl/models/generation/interfaces.py",
     "nemo_rl/models/generation/vllm/__init__.py",
     "nemo_rl/models/generation/vllm/config.py",
+    "nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py",
     "nemo_rl/models/generation/vllm/utils.py",
     "nemo_rl/models/generation/vllm/vllm_backend.py",
     "nemo_rl/models/huggingface/__init__.py",

0 commit comments

Comments
 (0)