diff --git a/metrics/perplexity/perplexity.py b/metrics/perplexity/perplexity.py
index 557172cd..a2b5ccb9 100644
--- a/metrics/perplexity/perplexity.py
+++ b/metrics/perplexity/perplexity.py
@@ -161,6 +161,9 @@ def _compute(
         ppls = []
         loss_fct = CrossEntropyLoss(reduction="none")
 
+        nll_sum = 0.0
+        total_tokens = 0
+
         for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
             end_index = min(start_index + batch_size, len(encoded_texts))
             encoded_batch = encoded_texts[start_index:end_index]
@@ -182,11 +185,21 @@ def _compute(
             shift_labels = labels[..., 1:].contiguous()
             shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
 
+            negative_log_likelihood_batch = (
+                loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch
+            ).sum(1)
+
             perplexity_batch = torch.exp(
-                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
-                / shift_attention_mask_batch.sum(1)
+                negative_log_likelihood_batch / shift_attention_mask_batch.sum(1)
             )
 
             ppls += perplexity_batch.tolist()
 
-        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}
+            nll_sum += negative_log_likelihood_batch.sum().item()
+            total_tokens += shift_attention_mask_batch.sum().item()
+
+        return {
+            "perplexities": ppls,
+            "mean_perplexity": np.mean(ppls),
+            "geometric_mean_perplexity": np.exp(nll_sum / total_tokens),
+        }
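
For context (not part of the diff): mean_perplexity is the arithmetic mean of the per-sequence perplexities, while the new geometric_mean_perplexity is exp(nll_sum / total_tokens), i.e. the token-weighted geometric mean of the per-sequence perplexities, so longer sequences contribute proportionally more. A minimal standalone sketch of the two aggregates, using made-up per-token NLL values rather than the metric's actual API:

import numpy as np

# Made-up per-token negative log-likelihoods for two sequences of different length.
nlls = [np.array([2.0, 1.5, 3.0]), np.array([0.5, 0.7])]

# Existing aggregate: per-sequence perplexity, then an unweighted arithmetic mean.
per_seq_ppl = [np.exp(seq.mean()) for seq in nlls]
mean_perplexity = np.mean(per_seq_ppl)

# New aggregate: exp of total NLL over total tokens, matching the diff's
# nll_sum / total_tokens accumulation; equivalent to a token-weighted
# geometric mean of the per-sequence perplexities.
nll_sum = sum(seq.sum() for seq in nlls)
total_tokens = sum(len(seq) for seq in nlls)
geometric_mean_perplexity = np.exp(nll_sum / total_tokens)

print(mean_perplexity, geometric_mean_perplexity)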