Shortcuts

Source code for ignite.metrics.nlp.perplexity

from collections.abc import Callable

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["Perplexity"]


[docs]class Perplexity(Metric): r"""Calculates the `Perplexity <https://en.wikipedia.org/wiki/Perplexity>`_ of a language model. .. math:: \text{PPL}(W) = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \ldots, w_{i-1}) \right) where :math:`N` is the total number of tokens and :math:`P(w_i | w_1, \ldots, w_{i-1})` is the conditional probability of token :math:`w_i` given the preceding tokens. Perplexity is computed as :math:`\exp(\text{NLL})` where NLL is the mean negative log-likelihood over all tokens. Lower perplexity indicates a better language model. - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. - `y_pred` must be a floating-point tensor of shape ``(batch_size, vocab_size, seq_len)`` containing the unnormalized log-probabilities (logits). - `y` must be a long tensor of shape ``(batch_size, seq_len)`` containing the target token indices. Note: Perplexity uses token-weighted accumulation rather than batch-average to avoid bias towards shorter sequences. The total NLL and total token count are accumulated across all batches, and the final perplexity is computed as ``exp(total_nll / total_tokens)``. Args: output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. Examples: For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. .. testcode:: from ignite.metrics.nlp import Perplexity import torch ppl = Perplexity() # batch_size=2, vocab_size=5, seq_len=3 y_pred = torch.randn(2, 5, 3) y = torch.randint(0, 5, (2, 3)) ppl.update((y_pred, y)) print(type(ppl.compute())) .. testoutput:: <class 'float'> .. versionadded:: 0.5.5 """ _state_dict_all_req_keys = ("_sum_of_nll", "_num_tokens") def __init__( self, output_transform: Callable = lambda x: x, device: str | torch.device = torch.device("cpu"), ignore_index: int = -100, ): self._ignore_index = ignore_index super().__init__(output_transform=output_transform, device=device)
[docs] @reinit__is_reduced def reset(self) -> None: self._sum_of_nll = torch.tensor(0.0, device=self._device) self._num_tokens = torch.tensor(0, device=self._device)
[docs] @reinit__is_reduced def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output y_pred = y_pred.detach() y = y.detach() if y_pred.ndim < 2: raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})") if y.ndim < 1: raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})") if y_pred.ndim != y.ndim + 1: raise ValueError( f"y_pred must have exactly one more dimension than y (got y_pred={y_pred.ndim}D and y={y.ndim}D)." ) if y_pred.shape[0] != y.shape[0] or y_pred.shape[2:] != y.shape[1:]: raise ValueError( f"y_pred and y have incompatible shapes: " f"y_pred={y_pred.shape}, y={y.shape} " f"(expected y_pred[0] == y[0] and y_pred[2:] == y[1:])." ) nll = F.cross_entropy(y_pred, y, reduction="sum", ignore_index=self._ignore_index) self._sum_of_nll += nll.to(self._device) self._num_tokens += (y != self._ignore_index).sum().to(self._device)
[docs] @sync_all_reduce("_sum_of_nll", "_num_tokens") def compute(self) -> float: if self._num_tokens == 0: raise NotComputableError("Perplexity must have at least one example before it can be computed.") return torch.exp(self._sum_of_nll / self._num_tokens).item()

© Copyright 2026, PyTorch-Ignite Contributors. Last updated on 06/30/2026, 8:38:43 PM.

Built with Sphinx using a theme provided by Read the Docs.
×

Search Docs