Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion tests/utils/debug/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch

from verl.protocol import DataProto
from verl.utils.debug.metrics import calculate_debug_metrics
from verl.utils.debug.metrics import _find_contiguous_segments, calculate_debug_metrics


class TestMetrics(unittest.TestCase):
Expand All @@ -43,6 +43,95 @@ def test_calculate_debug_metrics(self):
print(metrics)
assert metrics["training/rollout_probs_diff_valid"] == 1

def test_find_contiguous_segments(self):
    """Table-driven check of segment extraction over representative masks."""
    cases = [
        # (mask values, expected (start, end) pairs)
        ([1, 1, 1, 0, 0], [(0, 3)]),  # single segment
        ([1, 1, 0, 0, 1, 1, 1, 0, 1], [(0, 2), (4, 7), (8, 9)]),  # multi-turn
        ([0, 0, 0], []),  # all zeros
        ([1, 1, 1], [(0, 3)]),  # all ones
    ]
    for values, expected in cases:
        assert _find_contiguous_segments(torch.tensor(values)) == expected

def test_per_round_metrics_single_turn(self):
    """A fully-contiguous response mask should be counted as exactly one round."""
    rollout_lp = torch.tensor([[-1.0, -2.0, -3.0, -4.0]])
    # Training logprobs are offset by a constant 0.1 -> mean abs diff == 0.1.
    train_lp = torch.tensor([[-1.1, -2.1, -3.1, -4.1]])
    batch = {
        "rollout_log_probs": rollout_lp,
        "old_log_probs": train_lp,
        "response_mask": torch.tensor([[1, 1, 1, 1]]),
        "responses": torch.zeros((1, 4)),
    }
    metrics = calculate_debug_metrics(DataProto.from_dict(batch))
    assert metrics["per_round/total_rounds"] == 1
    assert "per_round/round_0_abs_diff_mean" in metrics
    self.assertAlmostEqual(metrics["per_round/round_0_abs_diff_mean"], 0.1, places=5)

def test_per_round_metrics_multi_turn(self):
    """Two generation rounds separated by environment tokens.

    Round 0 (positions 0-1) has identical logprobs -> diff 0.0.
    Round 1 (positions 4-5) differs by 1.0 per token -> diff 1.0.
    """
    batch = {
        "rollout_log_probs": torch.tensor([[-1.0, -2.0, -9.0, -9.0, -3.0, -4.0]]),
        "old_log_probs": torch.tensor([[-1.0, -2.0, -9.0, -9.0, -4.0, -5.0]]),
        "response_mask": torch.tensor([[1, 1, 0, 0, 1, 1]]),
        "responses": torch.zeros((1, 6)),
    }
    metrics = calculate_debug_metrics(DataProto.from_dict(batch))

    assert metrics["per_round/total_rounds"] == 2
    self.assertAlmostEqual(metrics["per_round/round_0_abs_diff_mean"], 0.0, places=5)
    self.assertAlmostEqual(metrics["per_round/round_1_abs_diff_mean"], 1.0, places=5)
    # The largest mismatch must be attributed to round 1.
    assert metrics["per_round/max_round_diff"] == 1
    self.assertAlmostEqual(metrics["per_round/max_diff_value"], 1.0, places=5)

def test_per_round_metrics_batch(self):
    """Batched samples with differing round counts aggregate per round index.

    Sample 0 contributes one round (positions 0-2); sample 1 contributes two
    rounds (positions 0-1 and 3-4), so round 0 pools tokens from both samples
    while round 1 comes from sample 1 alone.
    """
    rollout_lp = torch.tensor(
        [
            [-1.0, -2.0, -3.0, -9.0, -9.0],
            [-1.0, -2.0, -9.0, -3.0, -4.0],
        ]
    )
    train_lp = torch.tensor(
        [
            [-1.0, -2.0, -3.0, -9.0, -9.0],
            [-1.0, -2.0, -9.0, -3.5, -4.5],
        ]
    )
    mask = torch.tensor(
        [
            [1, 1, 1, 0, 0],
            [1, 1, 0, 1, 1],
        ]
    )
    data = DataProto.from_dict(
        {
            "rollout_log_probs": rollout_lp,
            "old_log_probs": train_lp,
            "response_mask": mask,
            "responses": torch.zeros((2, 5)),
        }
    )
    metrics = calculate_debug_metrics(data)

    assert metrics["per_round/total_rounds"] == 2  # max rounds across the batch
    assert "per_round/round_0_abs_diff_mean" in metrics
    assert "per_round/round_1_abs_diff_mean" in metrics
    assert metrics["per_round/round_0_token_count"] == 5  # 3 from sample 0 + 2 from sample 1
    assert metrics["per_round/round_1_token_count"] == 2  # only from sample 1


if __name__ == "__main__":
unittest.main()
121 changes: 120 additions & 1 deletion verl/utils/debug/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,118 @@ def calculate_log_prob_diff(log_probs1: torch.Tensor, log_probs2: torch.Tensor,
return torch.masked_select(full_diff, mask)


def _find_contiguous_segments(mask_1d: torch.Tensor) -> list[tuple[int, int]]:
"""Find contiguous segments of 1s in a 1D mask tensor.

Each contiguous segment of 1s represents one round of model generation
in a multi-turn trajectory. Segments are separated by 0s (environment
tokens like images, or padding).

Example:
mask = [1,1,1,1, 0,0,0,0,0, 1,1,1, 0,0,0,0,0, 1,1,1, 0,0,0]
|--R0--| |--env--| |--R1-| |--env--| |--R2-| |pad|
Returns: [(0, 4), (9, 12), (17, 20)]

Args:
mask_1d: 1D tensor with 0s and 1s

Returns:
List of (start, end) tuples for each contiguous segment of 1s.
end is exclusive (Python slice convention).
"""
segments = []
in_segment = False
start = 0

for i in range(len(mask_1d)):
val = mask_1d[i].item() if isinstance(mask_1d[i], torch.Tensor) else mask_1d[i]
if val == 1 and not in_segment:
in_segment = True
start = i
elif val == 0 and in_segment:
in_segment = False
segments.append((start, i))

if in_segment:
segments.append((start, len(mask_1d)))

return segments


def _calculate_per_round_metrics(
    train_log_probs: torch.Tensor,
    rollout_log_probs: torch.Tensor,
    response_mask: torch.Tensor,
) -> dict:
    """Calculate per-round logprob mismatch metrics for multi-turn trajectories.

    Identifies rounds by finding contiguous segments of 1s in response_mask,
    then computes mean absolute logprob difference per round.

    This is useful for multi-turn RL training where different rounds may have
    different attention mask behavior (e.g., image window attention), causing
    mismatch between training and rollout engines to vary across rounds.

    Args:
        train_log_probs: Log probs from training engine (batch_size, seq_len)
        rollout_log_probs: Log probs from rollout engine (batch_size, seq_len)
        response_mask: Mask for valid positions (batch_size, seq_len),
            1=model generated token, 0=environment token or padding

    Returns:
        Dictionary with per-round metrics:
        - per_round/total_rounds: Max number of rounds across batch
        - per_round/round_{i}_abs_diff_mean: Mean |logprob_train - logprob_rollout| for round i
        - per_round/round_{i}_token_count: Number of tokens in round i
        - per_round/max_round_diff: Which round has the largest mean diff
        - per_round/max_diff_value: The largest mean diff value
    """
    batch_size = train_log_probs.shape[0]

    # round_idx -> list of (train_vals, rollout_vals), pooled across the batch.
    all_round_data: dict[int, list[tuple[torch.Tensor, torch.Tensor]]] = {}
    max_rounds = 0

    for b in range(batch_size):
        segments = _find_contiguous_segments(response_mask[b])
        max_rounds = max(max_rounds, len(segments))

        for round_idx, (start, end) in enumerate(segments):
            # setdefault replaces the manual "if key not in dict" grouping.
            all_round_data.setdefault(round_idx, []).append(
                (train_log_probs[b, start:end], rollout_log_probs[b, start:end])
            )

    # No generated tokens anywhere in the batch.
    if not all_round_data:
        return {"per_round/total_rounds": 0}

    metrics: dict = {"per_round/total_rounds": max_rounds}
    max_diff = -1.0
    max_diff_round = -1

    for round_idx in sorted(all_round_data):
        pairs = all_round_data[round_idx]
        train_all = torch.cat([t for t, _ in pairs])
        rollout_all = torch.cat([r for _, r in pairs])

        # Defensive: segments from _find_contiguous_segments are never empty,
        # but skip a zero-token round rather than produce NaN means.
        if train_all.numel() == 0:
            continue

        mean_diff = torch.abs(train_all - rollout_all).mean().item()

        metrics[f"per_round/round_{round_idx}_abs_diff_mean"] = mean_diff
        metrics[f"per_round/round_{round_idx}_token_count"] = train_all.numel()

        if mean_diff > max_diff:
            max_diff = mean_diff
            max_diff_round = round_idx

    metrics["per_round/max_round_diff"] = max_diff_round
    if max_diff_round >= 0:
        metrics["per_round/max_diff_value"] = max_diff

    return metrics


def calculate_debug_metrics(data: DataProto) -> dict:
"""
calculate rollout vs actor logprobs diff, for debugging purpose
Expand Down Expand Up @@ -100,10 +212,17 @@ def calculate_debug_metrics(data: DataProto) -> dict:
response_mask_bool = response_mask.bool()
pearson_corrcoef = pearson_correlation_coefficient(actor_probs, rollout_probs, response_mask_bool)
rollout_probs_diff = calculate_log_prob_diff(actor_probs, rollout_probs, response_mask_bool)
return {

metrics = {
"training/rollout_probs_diff_valid": 1,
"training/rollout_probs_diff_max": torch.max(rollout_probs_diff).detach().item(),
"training/rollout_probs_diff_mean": torch.mean(rollout_probs_diff).detach().item(),
"training/rollout_probs_diff_std": torch.std(rollout_probs_diff).detach().item(),
"training/rollout_actor_probs_pearson_corr": pearson_corrcoef,
}

# Per-round logprob mismatch metrics for multi-turn trajectories
per_round_metrics = _calculate_per_round_metrics(actor_old_log_probs, rollout_old_log_probs, response_mask)
metrics.update(per_round_metrics)

return metrics