From 6c16a0594c7c85ce2737a9e43370f2022fc3f019 Mon Sep 17 00:00:00 2001
From: LuggiStruggi
Date: Mon, 2 Feb 2026 14:06:53 +0100
Subject: [PATCH 1/2] perplexity + geometric mean

---
 metrics/perplexity/perplexity.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/metrics/perplexity/perplexity.py b/metrics/perplexity/perplexity.py
index 557172cd..fe120d06 100644
--- a/metrics/perplexity/perplexity.py
+++ b/metrics/perplexity/perplexity.py
@@ -182,11 +182,21 @@ def _compute(
         shift_labels = labels[..., 1:].contiguous()
         shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
 
+        negative_log_likelihood_batch = (
+            loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch
+        ).sum(1)
+
         perplexity_batch = torch.exp(
-            (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
-            / shift_attention_mask_batch.sum(1)
+            negative_log_likelihood_batch / shift_attention_mask_batch.sum(1)
         )
 
         ppls += perplexity_batch.tolist()
 
-    return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}
+        nll_sum += negative_log_likelihood_batch.sum().item()
+        total_tokens += shift_attention_mask_batch.sum().item()
+
+    return {
+        "perplexities": ppls,
+        "mean_perplexity": np.mean(ppls),
+        "geometric_mean_perplexity": np.exp(nll_sum / total_tokens),
+    }

From b9de6e241bf9db7f7f38f119e3451b87def2f74d Mon Sep 17 00:00:00 2001
From: LuggiStruggi
Date: Mon, 2 Feb 2026 14:25:22 +0100
Subject: [PATCH 2/2] perplexity + geometric mean

---
 metrics/perplexity/perplexity.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/metrics/perplexity/perplexity.py b/metrics/perplexity/perplexity.py
index fe120d06..a2b5ccb9 100644
--- a/metrics/perplexity/perplexity.py
+++ b/metrics/perplexity/perplexity.py
@@ -161,6 +161,9 @@ def _compute(
     ppls = []
     loss_fct = CrossEntropyLoss(reduction="none")
 
+    nll_sum = 0.0
+    total_tokens = 0
+
     for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
         end_index = min(start_index + batch_size, len(encoded_texts))
         encoded_batch = encoded_texts[start_index:end_index]
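
Note (not part of the patch): a minimal, self-contained sketch of what the new "geometric_mean_perplexity" key computes, using toy tensors and assuming only torch and numpy. Since nll_sum is the masked negative log-likelihood summed over every scored token and total_tokens is the matching token count, np.exp(nll_sum / total_tokens) is the token-weighted geometric mean of the per-sequence perplexities, whereas the existing "mean_perplexity" is their arithmetic mean.

# Illustrative sketch only (not from the patch): mirrors the metric's per-batch
# computation on toy tensors to relate "geometric_mean_perplexity" to "perplexities".
import numpy as np
import torch
from torch.nn import CrossEntropyLoss

torch.manual_seed(0)
batch, seq_len, vocab = 4, 10, 50
logits = torch.randn(batch, seq_len, vocab)          # stand-in for model output logits
labels = torch.randint(0, vocab, (batch, seq_len))   # stand-in for input token ids
attn_mask = torch.ones(batch, seq_len, dtype=torch.long)
attn_mask[2, 7:] = 0                                 # pretend one sequence is padded

loss_fct = CrossEntropyLoss(reduction="none")

# Same shift as in _compute: predict token t+1 from tokens up to t.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_mask = attn_mask[..., 1:].contiguous()

# Per-sequence negative log-likelihood summed over non-padded positions.
nll_per_seq = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_mask).sum(1)
tokens_per_seq = shift_mask.sum(1)

perplexities = torch.exp(nll_per_seq / tokens_per_seq)        # "perplexities"
mean_perplexity = perplexities.mean().item()                  # "mean_perplexity" (arithmetic mean)
geometric_mean_perplexity = float(np.exp(nll_per_seq.sum().item() / tokens_per_seq.sum().item()))

# The new key equals the token-weighted geometric mean of the per-sequence perplexities.
weighted_geo = np.exp(np.average(np.log(perplexities.numpy()), weights=tokens_per_seq.numpy()))
assert np.isclose(geometric_mean_perplexity, weighted_geo)
print(mean_perplexity, geometric_mean_perplexity)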