Skip to content

Commit 2f78c84

Browse files
committed
pyrefly
Signed-off-by: Yuki Huang <[email protected]>
1 parent e163c04 commit 2f78c84

File tree

6 files changed: +50 −27 lines changed

examples/configs/distillation_math.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,12 +231,12 @@ logger:
231231
monitor_gpus: true
232232
wandb:
233233
project: "nemo-distillation"
234-
name: "distillation-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
234+
name: "distillation-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
235235
swanlab:
236236
project: "nemo-distillation"
237-
name: "distillation-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
237+
name: "distillation-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
238238
tensorboard:
239-
log_dir: "tb_logs-distillation-${data.dataset_name}"
239+
log_dir: "tb_logs-distillation-${data.train.dataset_name}"
240240
mlflow:
241241
experiment_name: "distillation-dev"
242242
run_name: "distillation-math-cl-logger"

examples/configs/distillation_math_megatron.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ logger:
147147
wandb_enabled: true
148148
wandb:
149149
project: "nemo-distillation"
150-
name: "distillation-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
150+
name: "distillation-megatron-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
151151
tensorboard:
152-
log_dir: "tb_logs-distillation-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
152+
log_dir: "tb_logs-distillation-megatron-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
153153
mlflow:
154-
run_name: "distillation-math-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
154+
run_name: "distillation-math-megatron-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
155155

156156
cluster:
157157
gpus_per_node: 8

examples/configs/sft.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,15 @@ logger:
205205
monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard
206206
wandb:
207207
project: "sft-dev"
208-
name: "sft-dev-${data.dataset_name}"
208+
name: "sft-dev-${data.train.dataset_name}"
209209
swanlab:
210210
project: "sft-dev"
211-
name: "sft-dev-${data.dataset_name}"
211+
name: "sft-dev-${data.train.dataset_name}"
212212
tensorboard:
213-
log_dir: "tb_logs-sft-dev-${data.dataset_name}"
213+
log_dir: "tb_logs-sft-dev-${data.train.dataset_name}"
214214
mlflow:
215215
experiment_name: "sft-dev"
216-
run_name: "sft-dev-${data.dataset_name}"
216+
run_name: "sft-dev-${data.train.dataset_name}"
217217
gpu_monitoring:
218218
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
219219
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

examples/configs/sft_vlm_3B.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ logger:
4343
monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard
4444
wandb:
4545
project: "sft-dev"
46-
name: "sft-dev-${data.dataset_name}"
46+
name: "sft-dev-${data.train.dataset_name}"
4747
tensorboard:
48-
log_dir: "tb_logs-sft-dev-${data.dataset_name}"
48+
log_dir: "tb_logs-sft-dev-${data.train.dataset_name}"
4949
gpu_monitoring:
5050
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
5151
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

nemo_rl/data/__init__.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,32 +15,49 @@
1515
from typing import Literal, NotRequired, TypedDict
1616

1717

18-
# TODO: split this typed dict up so it can be PreferenceDataConfig | ResponseDataConfig | etc
18+
class ResponseDatasetConfig(TypedDict):
19+
dataset_name: str
20+
data_path: NotRequired[str]
21+
input_key: NotRequired[str]
22+
output_key: NotRequired[str]
23+
split: NotRequired[str]
24+
prompt_file: NotRequired[str | None]
25+
system_prompt_file: NotRequired[str | None]
26+
env_name: NotRequired[str]
27+
download_dir: NotRequired[str]
28+
split_validation_size: NotRequired[float]
29+
30+
31+
# TODO: split this typed dict up so it can be PreferenceDatasetConfig | ResponseDatasetConfig | etc
1932
# so that we can type check the configs more rigorously as opposed to saying everything
2033
# is not required.
2134
class DataConfig(TypedDict):
2235
max_input_seq_length: int
23-
prompt_file: NotRequired[str | None]
24-
system_prompt_file: NotRequired[str | None]
25-
dataset_name: str
26-
val_dataset_name: NotRequired[str]
2736
add_bos: NotRequired[bool]
2837
add_eos: NotRequired[bool]
29-
input_key: NotRequired[str]
30-
output_key: NotRequired[str | None]
3138
add_generation_prompt: NotRequired[bool]
3239
add_system_prompt: NotRequired[bool]
33-
split: NotRequired[str | None]
3440
shuffle: bool
35-
seed: NotRequired[int | None]
36-
download_dir: NotRequired[str]
37-
train_data_path: NotRequired[str]
38-
val_data_paths: NotRequired[dict[str, str]]
3941
# Number of data loader workers.
4042
# Set to 8 or 10 for large batches to improve loading speed.
4143
# This saturates CPU threads without consuming too much memory
4244
# However, setting it too high might cause memory issues for long seqlens.
4345
num_workers: NotRequired[int]
46+
# dataset configs
47+
prompt_file: NotRequired[str | None]
48+
system_prompt_file: NotRequired[str | None]
49+
env_name: NotRequired[str]
50+
# TODO: remove NotRequired once preference dataset is refactored
51+
train: NotRequired[ResponseDatasetConfig]
52+
validation: NotRequired[ResponseDatasetConfig | None]
53+
# TODO: remove once preference dataset is refactored
54+
dataset_name: NotRequired[str]
55+
val_dataset_name: NotRequired[str]
56+
input_key: NotRequired[str]
57+
output_key: NotRequired[str | None]
58+
split: NotRequired[str]
59+
train_data_path: NotRequired[str]
60+
val_data_paths: NotRequired[dict[str, str]]
4461

4562

4663
# ===============================================================================

nemo_rl/data/datasets/response_datasets/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from typing import Any
1616

17+
from nemo_rl.data import ResponseDatasetConfig
1718
from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset
1819
from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset
1920
from nemo_rl.data.datasets.response_datasets.dapo_math import (
@@ -37,7 +38,7 @@
3738

3839

3940
# TODO: refactor this to use the new processor interface and RawDataset interface. https://github.com/NVIDIA-NeMo/RL/issues/1552
40-
def load_response_dataset(data_config, seed: int = 42):
41+
def load_response_dataset(data_config: ResponseDatasetConfig, seed: int = 42):
4142
"""Loads response dataset."""
4243
dataset_name = data_config["dataset_name"]
4344

@@ -49,7 +50,9 @@ def load_response_dataset(data_config, seed: int = 42):
4950
elif dataset_name == "tulu3_sft_mixture":
5051
base_dataset: Any = Tulu3SftMixtureDataset(**data_config, seed=seed)
5152
elif dataset_name == "openai_format":
52-
base_dataset: Any = OpenAIFormatDataset(**data_config)
53+
base_dataset: Any = OpenAIFormatDataset(
54+
**data_config # pyrefly: ignore[missing-argument] `data_path` is required for this class
55+
)
5356
# for rl training
5457
elif dataset_name == "OpenMathInstruct-2":
5558
# TODO: also test after SFT updated
@@ -76,7 +79,10 @@ def load_response_dataset(data_config, seed: int = 42):
7679
base_dataset: Any = Geometry3KDataset(**data_config)
7780
# fall back to load from JSON file
7881
elif dataset_name == "ResponseDataset":
79-
base_dataset: Any = ResponseDataset(**data_config, seed=seed)
82+
base_dataset: Any = ResponseDataset(
83+
**data_config, # pyrefly: ignore[missing-argument] `data_path` is required for this class
84+
seed=seed,
85+
)
8086
else:
8187
raise ValueError(
8288
f"Unsupported {dataset_name=}. "

0 commit comments

Comments (0)