Commit d811caf

Merge branch 'main' into chtruong/publish-docs

2 parents 3982658 + 748b9ca commit d811caf

19 files changed: +2256 additions, -563 deletions

docs/fp8.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -11,6 +11,10 @@ This module provides a suite of tools to enable FP8 quantization for large langu
 - Uses **TransformerEngine** for linear layer implementation.
 - Supports both **Deepseek-style sub-channel scaling** and **per-tensor scaling**.
 
+### Recommended recipe
+- For Hopper GPUs, we recommend FP8 (Deepseek-style) precision for both generation and training for the best convergence and speedup.
+- For Blackwell GPUs, FP8 (Deepseek-style) with an FP32 scaling factor is not supported in training, so we currently recommend FP8 precision for generation and BF16 for training. We are actively exploring other recipes for better performance.
+
 ## Integration with NeMo RL
 
 NeMo RL applies monkey patches to several core `vLLM` components to enable FP8 generation for reinforcement learning.
```
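To make the recipe above concrete, here is a rough, hypothetical helper that picks precisions by compute capability; it is not part of NeMo RL, and the function name and string return convention are made up for illustration.

```python
# Hypothetical sketch of the recipe above; not NeMo RL code.
# Hopper reports compute capability 9.x (sm_90); Blackwell reports 10.x (sm_100).
import torch

def recommended_precisions() -> tuple[str, str]:
    """Return (generation_precision, training_precision) for the local GPU."""
    major, _ = torch.cuda.get_device_capability()
    if major >= 10:
        # Blackwell: Deepseek-style FP8 with an FP32 scaling factor is not
        # supported in training, so fall back to BF16 for training.
        return ("fp8", "bf16")
    # Hopper: FP8 (Deepseek-style) for both generation and training.
    return ("fp8", "fp8")
```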

docs/guides/ft-launcher-guide.md

Lines changed: 58 additions & 0 deletions (new file)

# Fault Tolerance Launcher Guide

The `ft_launcher` is provided by `nvidia-resiliency-ext` (included in the NeMo RL dependencies) and enables automatic fault tolerance and recovery for distributed training runs.

## Key Arguments

| Argument | Description | Example |
|----------|-------------|---------|
| `--ft-cfg-path` | Path to the FT YAML config file | `examples/ft_launcher/ft_config.yaml` |
| `--ft-rank-heartbeat-timeout` | Heartbeat timeout in seconds | `450` |
| `--ft-initial-rank-heartbeat-timeout` | Initial heartbeat timeout (longer, to allow for setup) | `1200` |
| `--max-restarts` | Maximum number of restart attempts | `5` |

## Basic Usage

```bash
uv run ft_launcher \
    --ft-cfg-path examples/ft_launcher/ft_config.yaml \
    --ft-rank-heartbeat-timeout 450 \
    --ft-initial-rank-heartbeat-timeout 1200 \
    --max-restarts 5 \
    examples/run_grpo_math.py \
    --config <your_config.yaml>
```

## FT Config File (`examples/ft_launcher/ft_config.yaml`)

```yaml
fault_tolerance:
  initial_rank_heartbeat_timeout: 360
  restart_policy: any-failed
```

## Important Notes

1. **Checkpointing**: Enable checkpointing so that recovery can resume from the latest checkpoint (see the resume sketch after this guide):

   ```bash
   ++checkpointing.enabled=true
   ++checkpointing.checkpoint_dir=/path/to/checkpoints
   ++checkpointing.save_period=50
   ```

2. **Timeouts**: Set `--ft-initial-rank-heartbeat-timeout` higher than `--ft-rank-heartbeat-timeout` to allow for model loading and setup time.

3. **Restart Policy**: The `any-failed` restart policy restarts the entire job if any rank fails. Look for these log messages to identify when a restart occurs:

   ```
   [ERROR] [ft_launcher...] failed (exitcode: 1) local_rank: 0 (pid: ...) of binary: ...
   [INFO] [ft_launcher...] [default] Worker group FAILED. 3/5 attempts left; will restart worker group
   [INFO] [ft_launcher...] Stopping workers... Timeout = 30 sec.
   [INFO] [ft_launcher...] The node '...' attempts to join the next round of the rendezvous '...'.
   [INFO] [ft_launcher...] The node '...' has joined round N of the rendezvous '...' as rank 0 in a world of size 1.
   ```

   Key indicators:
   - `Worker group FAILED. X/Y attempts left` - a restart is happening, with the number of attempts remaining
   - `will restart worker group` - confirms a restart is in progress
   - `has joined round N` - the round number increases with each restart
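To make the checkpointing note concrete, here is a minimal, hedged sketch of restart-safe resume logic. It assumes checkpoints land under `checkpoint_dir` as `step_<N>` directories; that layout and the helper name are illustrative, not the actual NeMo RL checkpoint format or API.

```python
# Hedged sketch: resume-from-latest logic for a restart-safe training script.
# Assumes checkpoints are saved as step_<N> directories under checkpoint_dir;
# the real NeMo RL checkpoint layout and API may differ.
import os
import re

def latest_checkpoint(checkpoint_dir: str) -> str | None:
    """Return the path of the newest step_<N> checkpoint, or None."""
    if not os.path.isdir(checkpoint_dir):
        return None
    steps = []
    for name in os.listdir(checkpoint_dir):
        match = re.fullmatch(r"step_(\d+)", name)
        if match:
            steps.append((int(match.group(1)), os.path.join(checkpoint_dir, name)))
    return max(steps)[1] if steps else None

# On every (re)start by ft_launcher, resume instead of starting from scratch.
ckpt = latest_checkpoint("/path/to/checkpoints")
if ckpt is not None:
    print(f"Resuming from {ckpt}")
```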

docs/index.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -219,6 +219,7 @@ guides/deepseek.md
 model-quirks.md
 guides/async-grpo.md
 guides/dtensor-tp-accuracy.md
+guides/ft-launcher-guide.md
 ```
 
 ```{toctree}
````
examples/ft_launcher/ft_config.yaml

Lines changed: 3 additions & 0 deletions (new file)

```diff
@@ -0,0 +1,3 @@
+fault_tolerance:
+  initial_rank_heartbeat_timeout: 360
+  restart_policy: any-failed
```

nemo_rl/algorithms/distillation.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -695,7 +695,9 @@ def distillation_train(
         print("▶ Computing teacher logprobs...", flush=True)
         with timer.time("teacher_logprob_inference"):
             teacher_topk = teacher_policy.get_topk_logits(
-                train_data, k=master_config["distillation"]["topk_logits_k"]
+                train_data,
+                k=master_config["distillation"]["topk_logits_k"],
+                timer=timer,
             )
         train_data["teacher_topk_logits"] = teacher_topk["topk_logits"]
         train_data["teacher_topk_indices"] = teacher_topk["topk_indices"]
@@ -708,7 +710,11 @@ def distillation_train(
 
         print("▶ Training policy...", flush=True)
         with timer.time("policy_training"):
-            train_results = student_policy.train(train_data, loss_fn)
+            train_results = student_policy.train(
+                train_data,
+                loss_fn,
+                timer=timer,
+            )
 
         is_last_step = (total_steps + 1 >= max_steps) or (
             (current_epoch + 1 == max_epochs)
```
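The recurring change in this commit (here and in dpo.py, grpo.py, rm.py, and sft.py below) threads the outer `timer` into `train`/`get_logprobs`-style calls so sub-phases are recorded on the same timer that wraps each step. Below is a minimal sketch of that pattern; the `Timer` class is hypothetical and only mirrors the `with timer.time(name)` usage visible in these hunks, not the actual NeMo RL Timer API.

```python
# Minimal sketch of the timer-threading pattern; this Timer is hypothetical
# and only mirrors the `with timer.time(name)` usage in the hunks above.
import time
from collections import defaultdict
from contextlib import contextmanager

class Timer:
    def __init__(self):
        self.totals = defaultdict(float)

    @contextmanager
    def time(self, name: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start

def train(data, loss_fn, timer: Timer):
    # Because the caller passes its own timer, inner phases land in the
    # same report as the outer "policy_training" span.
    with timer.time("forward_backward"):
        pass  # ... model update would go here

timer = Timer()
with timer.time("policy_training"):
    train(data=[], loss_fn=None, timer=timer)
print(dict(timer.totals))
```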

nemo_rl/algorithms/dpo.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -572,6 +572,7 @@ def dpo_train(
             ## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch.
             gbs=master_config["policy"]["train_global_batch_size"] * 2,
             mbs=master_config["policy"]["train_micro_batch_size"] * 2,
+            timer=timer,
         )
 
         is_last_step = total_steps + 1 >= master_config["dpo"][
```
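The `* 2` on `gbs`/`mbs` here (and in the rm.py hunks below) follows from the comment in the hunk: each preference example expands into two rows, chosen and rejected, that must stay in the same microbatch. A toy illustration with made-up data:

```python
# Toy illustration of why batch sizes are doubled for preference data;
# the data and sizes here are made up.
pairs = [("chosen_0", "rejected_0"), ("chosen_1", "rejected_1"),
         ("chosen_2", "rejected_2"), ("chosen_3", "rejected_3")]
rows = [row for pair in pairs for row in pair]  # chosen/rejected interleaved

mbs = 2                  # microbatch size counted in preference examples
effective_mbs = mbs * 2  # rows per microbatch, so no pair is ever split
microbatches = [rows[i:i + effective_mbs] for i in range(0, len(rows), effective_mbs)]
assert all(len(mb) % 2 == 0 for mb in microbatches)
print(microbatches[0])   # ['chosen_0', 'rejected_0', 'chosen_1', 'rejected_1']
```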

nemo_rl/algorithms/grpo.py

Lines changed: 23 additions & 10 deletions

```diff
@@ -1516,17 +1516,18 @@ def grpo_train(
                     **extra_multimodal_data,
                 }
             )
-            train_data["prev_logprobs"] = policy.get_logprobs(logprob_data)[
-                "logprobs"
-            ]
+            train_data["prev_logprobs"] = policy.get_logprobs(
+                logprob_data, timer=timer
+            )["logprobs"]
 
             if not master_config["grpo"].get(
                 "skip_reference_policy_logprobs_calculation"
             ):
                 train_data["reference_policy_logprobs"] = (
-                    policy.get_reference_policy_logprobs(logprob_data)[
-                        "reference_logprobs"
-                    ]
+                    policy.get_reference_policy_logprobs(
+                        logprob_data,
+                        timer=timer,
+                    )["reference_logprobs"]
                 )
 
             del logprob_data
@@ -1540,7 +1541,11 @@ def grpo_train(
 
         print("▶ Training policy...", flush=True)
         with timer.time("policy_training"):
-            train_results = policy.train(train_data, loss_fn)
+            train_results = policy.train(
+                train_data,
+                loss_fn,
+                timer=timer,
+            )
 
         # Recompute KV scales after policy training if needed
         if sync_kv_scales:
@@ -2510,9 +2515,13 @@ def async_grpo_train(
 
         print("▶ Computing logprobs...")
         with timer.time("policy_and_reference_logprobs"):
-            fprop_logprobs = policy.get_logprobs(train_data)["logprobs"]
+            fprop_logprobs = policy.get_logprobs(
+                train_data,
+                timer=timer,
+            )["logprobs"]
             reference_logprobs = policy.get_reference_policy_logprobs(
-                train_data
+                train_data,
+                timer=timer,
             )["reference_logprobs"]
         train_data["prev_logprobs"] = fprop_logprobs
         train_data["reference_policy_logprobs"] = reference_logprobs
@@ -2524,7 +2533,11 @@ def async_grpo_train(
 
         print("▶ Training policy...")
         with timer.time("policy_training"):
-            train_results = policy.train(train_data, loss_fn)
+            train_results = policy.train(
+                train_data,
+                loss_fn,
+                timer=timer,
+            )
 
         print("🔄 Synchronizing policy weights to trajectory collector…")
         generation_logger_metrics = None
```

nemo_rl/algorithms/rm.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -343,6 +343,7 @@ def validate_one_dataset(
         # NOTE: we double the batch size because each preference example corresponds to a pair of
         # examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch.
         mbs=val_mbs * 2,
+        timer=timer,
     )
 
     if len(val_results["all_mb_metrics"]) == 0:
@@ -503,6 +504,7 @@ def rm_train(
             ## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch.
             gbs=master_config["policy"]["train_global_batch_size"] * 2,
             mbs=master_config["policy"]["train_micro_batch_size"] * 2,
+            timer=timer,
         )
 
         is_last_step = (
```

nemo_rl/algorithms/sft.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -452,7 +452,11 @@ def sft_train(
 
         print("▶ Taking a training step...")
         with timer.time("policy_training"):
-            train_results = policy.train(train_data, loss_fn)
+            train_results = policy.train(
+                train_data,
+                loss_fn,
+                timer=timer,
+            )
 
         is_last_step = total_steps + 1 >= master_config["sft"][
             "max_num_steps"
```

nemo_rl/data/datasets/utils.py

Lines changed: 11 additions & 3 deletions

```diff
@@ -63,15 +63,23 @@ def pil_to_base64(image: Image.Image, format: str = "PNG") -> str:
 
 
 def load_dataset_from_path(data_path: str, data_split: Optional[str] = "train"):
-    """Load a dataset from a json, huggingface dataset, or Arrow dataset (saved with save_to_disk).
+    """Load a dataset from a local file, huggingface dataset, or Arrow dataset (saved with save_to_disk).
 
     Args:
         data_path: The path to the dataset.
         data_split: The split to load from the dataset.
     """
+    FILEEXT2TYPE = {
+        ".arrow": "arrow",
+        ".csv": "csv",
+        ".json": "json",
+        ".jsonl": "json",
+        ".parquet": "parquet",
+        ".txt": "text",
+    }
     suffix = os.path.splitext(data_path)[-1]
-    if suffix in [".json", ".jsonl"]:
-        raw_dataset = load_dataset("json", data_files=data_path)
+    if dataset_type := FILEEXT2TYPE.get(suffix):
+        raw_dataset = load_dataset(dataset_type, data_files=data_path)
     else:
         try:
             raw_dataset = load_dataset(data_path)
```
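For reference, here is a self-contained sketch of the extension-based dispatch this hunk introduces, using the Hugging Face `datasets` builders named in `FILEEXT2TYPE`; the standalone function name is illustrative, not the repo's.

```python
# Standalone sketch of the new extension -> builder dispatch; mirrors the
# hunk above using Hugging Face datasets. The function name is illustrative.
import os
from datasets import load_dataset

FILEEXT2TYPE = {
    ".arrow": "arrow",
    ".csv": "csv",
    ".json": "json",
    ".jsonl": "json",
    ".parquet": "parquet",
    ".txt": "text",
}

def load_local_dataset(data_path: str, split: str = "train"):
    suffix = os.path.splitext(data_path)[-1]
    if (dataset_type := FILEEXT2TYPE.get(suffix)) is not None:
        # e.g. "train.parquet" -> load_dataset("parquet", data_files="train.parquet")
        return load_dataset(dataset_type, data_files=data_path, split=split)
    # Otherwise treat the path as a Hub dataset name or a dataset directory.
    return load_dataset(data_path, split=split)

# Usage: load_local_dataset("data/train.jsonl") loads via the "json" builder.
```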
