[trainer] feat: add per-round logprob mismatch metrics for multi-turn training#5229

Open
aoshen524 wants to merge 1 commit into verl-project:main from aoshen524:feat/per-round-logprob-mismatch-metrics

Conversation

Contributor

@aoshen524 aoshen524 commented Feb 7, 2026

What does this PR do?

Add per-round logprob mismatch metrics for multi-turn RL training. In multi-turn trajectories, the response_mask contains contiguous segments of 1s for each round of model generation, separated by 0s for environment tokens (e.g., images). This PR detects those segments and computes per-round mean absolute logprob difference between rollout and actor, making it easy to identify which round diverges most.
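
As a toy illustration (hypothetical values, not from the PR):

#   response_mask = [1, 1, 1, 0, 0, 1, 1, 1, 1]
#   round 0 -> positions 0..2, round 1 -> positions 5..8
#   per-round metric = mean(|actor_log_prob - rollout_log_prob|) over each span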

This extends the existing debug metrics (#1712, #2808) without changing the existing API or adding any new dependencies.

Checklist Before Starting

Test

Existing unit test passes unchanged:

tests/utils/debug/test_metrics.py::TestMetrics::test_calculate_debug_metrics PASSED

Tested with multi-turn VLM RL training (Qwen2.5-VL) on 8-turn GUI agent trajectories.

API and Usage Example

No API changes. calculate_debug_metrics(data) signature is unchanged. The returned dict now includes additional per_round/ prefixed keys:

from verl.utils.debug.metrics import calculate_debug_metrics

metrics = calculate_debug_metrics(batch)
# Existing keys still present:
#   training/rollout_probs_diff_mean, training/rollout_probs_diff_max, etc.
#
# New per-round keys (only meaningful for multi-turn):
#   per_round/total_rounds: 8
#   per_round/round_0_abs_diff_mean: 0.000123
#   per_round/round_0_token_count: 304
#   ...
#   per_round/round_7_abs_diff_mean: 0.045678
#   per_round/round_7_token_count: 376
#   per_round/max_round_diff: 7
#   per_round/max_diff_value: 0.045678
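
# For example, the round with the largest mismatch can be read straight back
# out of the returned dict (illustrative snippet using only the keys shown above):
worst_round = metrics["per_round/max_round_diff"]
worst_diff = metrics["per_round/max_diff_value"]
print(f"round {worst_round} has the largest rollout/actor abs logprob diff: {worst_diff:.6f}")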

Design & Code Changes

verl/utils/debug/metrics.py (+120 lines, single file change):

  • _find_contiguous_segments(mask_1d): Finds contiguous segments of 1s in response_mask to identify round boundaries
  • _calculate_per_round_metrics(train_log_probs, rollout_log_probs, response_mask): Computes mean absolute logprob diff per round, aggregated across the batch
  • calculate_debug_metrics(): Now calls _calculate_per_round_metrics and merges the per-round results into the returned dict (an illustrative sketch of the two helpers follows this list)
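
An illustrative sketch of what the two helpers could look like; the function names are taken from the list above, but the bodies and the assumed (batch, seq_len) tensor shapes are a sketch, not the actual code in verl/utils/debug/metrics.py:

import torch


def _find_contiguous_segments(mask_1d):
    """Return (start, end) pairs (end exclusive) for each run of 1s in a 1-D 0/1 mask."""
    segments = []
    start = None
    for i, v in enumerate(mask_1d.tolist()):
        if v == 1 and start is None:
            start = i  # a new round begins
        elif v == 0 and start is not None:
            segments.append((start, i))  # the current round ends
            start = None
    if start is not None:
        segments.append((start, mask_1d.shape[0]))  # last round runs to the end of the sequence
    return segments


def _calculate_per_round_metrics(train_log_probs, rollout_log_probs, response_mask):
    """Aggregate mean |actor - rollout| logprob difference per round across the batch."""
    abs_diff = (train_log_probs - rollout_log_probs).abs()
    round_diffs = {}   # round index -> list of per-sample diff slices
    round_counts = {}  # round index -> total token count
    for b in range(response_mask.shape[0]):
        for k, (s, e) in enumerate(_find_contiguous_segments(response_mask[b])):
            round_diffs.setdefault(k, []).append(abs_diff[b, s:e])
            round_counts[k] = round_counts.get(k, 0) + (e - s)

    metrics = {"per_round/total_rounds": len(round_diffs)}
    max_round, max_val = -1, float("-inf")
    for k in sorted(round_diffs):
        mean_diff = torch.cat(round_diffs[k]).mean().item()
        metrics[f"per_round/round_{k}_abs_diff_mean"] = mean_diff
        metrics[f"per_round/round_{k}_token_count"] = round_counts[k]
        if mean_diff > max_val:
            max_round, max_val = k, mean_diff
    if max_round >= 0:
        metrics["per_round/max_round_diff"] = max_round
        metrics["per_round/max_diff_value"] = max_val
    return metrics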

Checklist Before Submitting

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a comprehensive set of metrics for debugging training-inference mismatch in multi-turn RL scenarios. The changes are primarily in verl/utils/debug/metrics.py, with a minor change in verl/trainer/ppo/ray_trainer.py to enable the new metrics. The new metrics include perplexity, KL divergence, importance weights, and more, with per-round analysis. The code is well-structured and includes good documentation. I found one issue where a newly implemented metric function is not being used, which I've commented on.

metrics.update(calculate_chi_squared_metrics(train_log_probs, rollout_log_probs, mask_tensor))
metrics.update(calculate_log_prob_distribution_metrics(train_log_probs, rollout_log_probs, mask_tensor))
metrics.update(calculate_catastrophic_mismatch_metrics(train_log_probs, rollout_log_probs, mask_tensor))


Severity: high

The function calculate_sequence_level_mismatch_metrics is defined but never called within calculate_debug_metrics. This seems like an oversight, as the PR description mentions adding sequence-level metrics. You should add a call to this function to include its metrics in the output.

Suggested change
metrics.update(calculate_sequence_level_mismatch_metrics(train_log_probs, rollout_log_probs, mask_tensor))

@aoshen524 aoshen524 force-pushed the feat/per-round-logprob-mismatch-metrics branch from 43b0fc2 to 8e2a0ae on February 7, 2026, 13:44
[trainer] feat: add per-round logprob mismatch metrics for multi-turn training

Add per-round analysis of rollout vs actor logprob mismatch for multi-turn
RL training. Uses response_mask to identify contiguous segments (rounds)
and computes mean absolute logprob difference per round.

This helps diagnose which conversation round contributes most to the
training-inference mismatch in multi-turn scenarios.
@aoshen524 aoshen524 force-pushed the feat/per-round-logprob-mismatch-metrics branch from 8e2a0ae to 63b63f3 on February 7, 2026, 13:52
