[trainer] feat: add per-round logprob mismatch metrics for multi-turn training #5229
aoshen524 wants to merge 1 commit into verl-project:main
Conversation
Code Review
This pull request introduces a comprehensive set of metrics for debugging training-inference mismatch in multi-turn RL scenarios. The changes are primarily in verl/utils/debug/metrics.py, with a minor change in verl/trainer/ppo/ray_trainer.py to enable the new metrics. The new metrics include perplexity, KL divergence, importance weights, and more, with per-round analysis. The code is well-structured and includes good documentation. I found one issue where a newly implemented metric function is not being used, which I've commented on.
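For context, here is a minimal sketch of how such token-level mismatch quantities are typically derived from the actor (train) and rollout log-probabilities; the function and key names below are illustrative, not the actual implementation in this PR:

```python
import torch

def token_mismatch_stats(train_log_probs, rollout_log_probs, response_mask):
    """Illustrative token-level mismatch stats from actor vs rollout log-probs.

    train_log_probs, rollout_log_probs: (batch, seq_len) log-probs of the same tokens.
    response_mask: (batch, seq_len), 1 for model-generated tokens, 0 otherwise.
    """
    mask = response_mask.bool()
    diff = train_log_probs - rollout_log_probs              # per-token log importance ratio
    importance_weight = diff.exp()                          # pi_train(token) / pi_rollout(token)
    kl_estimate = (-diff)[mask].mean()                      # k1 estimate of KL(rollout || train)
    train_ppl = torch.exp(-train_log_probs[mask].mean())    # actor perplexity on rollout tokens
    return {
        "mismatch/kl": kl_estimate.item(),
        "mismatch/importance_weight_mean": importance_weight[mask].mean().item(),
        "mismatch/train_ppl": train_ppl.item(),
    }
```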
verl/utils/debug/metrics.py
Outdated
```python
metrics.update(calculate_chi_squared_metrics(train_log_probs, rollout_log_probs, mask_tensor))
metrics.update(calculate_log_prob_distribution_metrics(train_log_probs, rollout_log_probs, mask_tensor))
metrics.update(calculate_catastrophic_mismatch_metrics(train_log_probs, rollout_log_probs, mask_tensor))
```
The function calculate_sequence_level_mismatch_metrics is defined but never called within calculate_debug_metrics. This seems like an oversight, as the PR description mentions adding sequence-level metrics. You should add a call to this function to include its metrics in the output.
```python
metrics.update(calculate_sequence_level_mismatch_metrics(train_log_probs, rollout_log_probs, mask_tensor))
```
aoshen524 force-pushed from 43b0fc2 to 8e2a0ae
… training

Add per-round analysis of rollout vs actor logprob mismatch for multi-turn RL training. Uses `response_mask` to identify contiguous segments (rounds) and computes the mean absolute logprob difference per round. This helps diagnose which conversation round contributes most to the training-inference mismatch in multi-turn scenarios.
aoshen524 force-pushed from 8e2a0ae to 63b63f3
What does this PR do?
Add per-round logprob mismatch metrics for multi-turn RL training. In multi-turn trajectories, the `response_mask` contains contiguous segments of 1s for each round of model generation, separated by 0s for environment tokens (e.g., images). This PR detects those segments and computes the per-round mean absolute logprob difference between rollout and actor, making it easy to identify which round diverges most.

This extends the existing debug metrics (#1712, #2808) without changing the existing API or adding any new dependencies.
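As a concrete example (the values below are made up for illustration, and the variable names are not from the PR): a single trajectory with two generation rounds has two contiguous 1-segments in its mask, and each round is scored by the mean |Δlogprob| over its own tokens:

```python
import torch

# response_mask: 1 = model-generated token, 0 = environment token (e.g. image)
response_mask = torch.tensor([1, 1, 0, 0, 1, 1, 1])
train_lp = torch.tensor([-0.5, -0.7, 0.0, 0.0, -1.0, -1.2, -0.9])    # actor logprobs
rollout_lp = torch.tensor([-0.6, -0.7, 0.0, 0.0, -0.4, -0.5, -0.3])  # rollout logprobs

abs_diff = (train_lp - rollout_lp).abs()
# Round 0 covers tokens 0..1, round 1 covers tokens 4..6 (the two 1-segments).
round_0 = abs_diff[0:2].mean()   # (0.1 + 0.0) / 2 = 0.05
round_1 = abs_diff[4:7].mean()   # (0.6 + 0.7 + 0.6) / 3 ≈ 0.63
# Round 1 diverges most, so that is where the training-inference mismatch concentrates.
```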
Checklist Before Starting
PR title follows the `[{modules}] {type}: {description}` format.

Test
Existing unit test passes unchanged:
Tested with multi-turn VLM RL training (Qwen2.5-VL) on 8-turn GUI agent trajectories.
API and Usage Example
No API changes.
The `calculate_debug_metrics(data)` signature is unchanged. The returned dict now includes additional `per_round/`-prefixed keys:
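For example (the key names here are hypothetical; only the `per_round/` prefix is stated by this PR):

```python
metrics = calculate_debug_metrics(data)
# Existing keys are unchanged; new per-round keys are added alongside them, e.g.:
# {
#     ...,                                        # existing mismatch metrics
#     "per_round/round_0/abs_diff_mean": 0.05,    # hypothetical key names
#     "per_round/round_1/abs_diff_mean": 0.63,
# }
```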
Design & Code Changes

`verl/utils/debug/metrics.py` (+120 lines, single file change):

- `_find_contiguous_segments(mask_1d)`: finds contiguous segments of 1s in `response_mask` to identify round boundaries
- `_calculate_per_round_metrics(train_log_probs, rollout_log_probs, response_mask)`: computes the mean absolute logprob diff per round, aggregated across the batch
- `calculate_debug_metrics()`: now calls `_calculate_per_round_metrics` and includes the results (see the sketch below)
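A rough sketch of what these two helpers could look like (the names come from the PR description, but the bodies and the batch-averaging scheme are assumptions, not the actual code):

```python
import torch
from collections import defaultdict

def _find_contiguous_segments(mask_1d):
    """Return (start, end) pairs for contiguous runs of 1s in a 1D 0/1 tensor."""
    m = mask_1d.long()
    padded = torch.cat([m.new_zeros(1), m, m.new_zeros(1)])
    flips = padded[1:] - padded[:-1]              # +1 where a run starts, -1 where it ends
    starts = (flips == 1).nonzero(as_tuple=True)[0].tolist()
    ends = (flips == -1).nonzero(as_tuple=True)[0].tolist()
    return list(zip(starts, ends))

def _calculate_per_round_metrics(train_log_probs, rollout_log_probs, response_mask):
    """Mean |logprob diff| per round, averaged over trajectories that reach that round.

    All inputs are (batch, seq_len) tensors.
    """
    abs_diff = (train_log_probs - rollout_log_probs).abs()
    per_round = defaultdict(list)
    for b in range(response_mask.shape[0]):
        for i, (s, e) in enumerate(_find_contiguous_segments(response_mask[b])):
            per_round[i].append(abs_diff[b, s:e].mean().item())
    return {f"per_round/round_{i}/abs_diff_mean": sum(v) / len(v) for i, v in per_round.items()}
```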
Checklist Before Submitting

`ci-request` channel