Perplexity
- class ignite.metrics.Perplexity(output_transform=<function Perplexity.<lambda>>, device=device(type='cpu'), ignore_index=-100)
Calculates the Perplexity of a language model.
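$$PPL(W) = \exp\left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i \mid w_1, \ldots, w_{i-1}) \right)$$
where $N$ is the total number of tokens and $P(w_i \mid w_1, \ldots, w_{i-1})$ is the conditional probability of token $w_i$ given the preceding tokens.
Perplexity is computed as $\exp(\text{NLL})$, where NLL is the mean negative log-likelihood over all tokens. Lower perplexity indicates a better language model.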
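As a worked illustration of the definition, the sketch below computes perplexity directly from per-token probabilities using only the standard library. The function name `perplexity` and the probability values are made up for this example; this is not the metric's implementation.

```python
import math


def perplexity(token_probs):
    """Perplexity of a sequence from the model's probability for each observed token."""
    # Mean negative log-likelihood over all N tokens.
    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(nll)


# Hypothetical probabilities P(w_i | w_1, ..., w_{i-1}) for a 4-token sequence.
probs = [0.5, 0.25, 0.25, 0.5]
ppl = perplexity(probs)

# A model that is uniform over a 5-word vocabulary has perplexity 5.
uniform_ppl = perplexity([1 / 5] * 7)
```

For the 4-token sequence the product of the probabilities is 1/64, so the perplexity is 64^(1/4), about 2.83. The uniform case shows why perplexity is often read as an effective number of equally likely choices per token.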
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}. y_pred must be a floating-point tensor of shape (batch_size, vocab_size, seq_len) containing the unnormalized log-probabilities (logits). y must be a long tensor of shape (batch_size, seq_len) containing the target token indices.
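The shape convention above, with the vocabulary on dimension 1, matches what `torch.nn.functional.cross_entropy` accepts. The sketch below builds inputs of that shape and computes a reference perplexity in plain PyTorch. It is an independent check under that assumption, not the metric's own code, and names such as `reference_ppl` and `logits_bsv` are illustrative.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, vocab_size, seq_len = 2, 5, 3

# Logits with the vocabulary on dim 1, as the metric expects.
y_pred = torch.randn(batch_size, vocab_size, seq_len)
# Target token indices; torch.randint returns int64 (long) by default.
y = torch.randint(0, vocab_size, (batch_size, seq_len))

# cross_entropy takes (N, C, d1) logits with (N, d1) targets, so no reshape is needed.
total_nll = F.cross_entropy(y_pred, y, reduction="sum").item()
reference_ppl = math.exp(total_nll / y.numel())

# Same quantity by hand: log-softmax over the vocab dim, then pick the target entries.
log_probs = F.log_softmax(y_pred, dim=1)
picked = log_probs.gather(1, y.unsqueeze(1)).squeeze(1)  # shape (batch_size, seq_len)
manual_ppl = math.exp(-picked.mean().item())

# A model emitting (batch_size, seq_len, vocab_size) would need its last two dims swapped.
logits_bsv = torch.randn(batch_size, seq_len, vocab_size)
logits_bvs = logits_bsv.transpose(1, 2)
```

Many language models return logits as (batch_size, seq_len, vocab_size), so the final transpose is the step most likely to be needed in an output_transform.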
Note
Perplexity uses token-weighted accumulation rather than a per-batch average, to avoid bias towards shorter sequences. The total NLL and total token count are accumulated across all batches, and the final perplexity is computed as exp(total_nll / total_tokens).
- Parameters:
output_transform (Callable) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. By default, metrics require the output as (y_pred, y) or {'y_pred': y_pred, 'y': y}.
device (str | device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
ignore_index (int) –
Examples
For more information on how the metric works with Engine, visit Attach Engine API.

    from ignite.metrics.nlp import Perplexity
    import torch

    ppl = Perplexity()

    # batch_size=2, vocab_size=5, seq_len=3
    y_pred = torch.randn(2, 5, 3)
    y = torch.randint(0, 5, (2, 3))

    ppl.update((y_pred, y))
    print(type(ppl.compute()))

<class 'float'>
New in version 0.5.5.
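To make the difference between token-weighted accumulation and batch averaging concrete, the following standard-library sketch compares the two on made-up per-batch statistics. The numbers and variable names are hypothetical and chosen only so the arithmetic is easy to follow.

```python
import math

# Hypothetical per-batch statistics: (summed NLL over the batch, number of tokens).
batches = [
    (6.0, 2),  # short batch, mean NLL 3.0 per token
    (8.0, 8),  # long batch, mean NLL 1.0 per token
]

# Token-weighted: accumulate totals across batches, then exponentiate once.
total_nll = sum(nll for nll, _ in batches)
total_tokens = sum(n for _, n in batches)
token_weighted_ppl = math.exp(total_nll / total_tokens)

# Batch-averaged: average the per-batch means, which over-weights the short batch.
mean_of_means = sum(nll / n for nll, n in batches) / len(batches)
batch_averaged_ppl = math.exp(mean_of_means)
```

With these values the token-weighted result is exp(1.4), about 4.06, while the batch-averaged result is exp(2.0), about 7.39. The 2-token batch counts as much as the 8-token batch in the second calculation, which is the bias that token-weighted accumulation avoids.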
Methods
compute – Computes the metric based on its accumulated state.
reset – Resets the metric to its initial state.
update – Updates the metric's state using the passed batch output.
- compute()
Computes the metric based on its accumulated state.
By default, this is called at the end of each epoch.
- Returns:
the actual quantity of interest. However, if a Mapping is returned, it will be (shallow) flattened into engine.state.metrics when completed() is called.
- Return type:
Any
- Raises:
NotComputableError – raised when the metric cannot be computed.
- reset()
Resets the metric to its initial state.
By default, this is called at the start of each epoch.
- Return type:
None
- update(output)
Updates the metric’s state using the passed batch output.
By default, this is called once for each batch.
- Parameters:
output (tuple[torch.Tensor, torch.Tensor]) – the output from the engine’s process function.
- Return type:
None
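The reset, update and compute lifecycle described above can be pictured with a small stand-in class. `RunningPerplexity` is hypothetical and deliberately simplified: its update takes a pre-computed NLL sum and token count instead of (y_pred, y), and it raises a plain RuntimeError where the real metric documents NotComputableError.

```python
import math


class RunningPerplexity:
    """Hypothetical stand-in that mimics the reset/update/compute lifecycle."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Called at the start of each epoch: clear the accumulated state.
        self._total_nll = 0.0
        self._total_tokens = 0

    def update(self, batch_nll_sum, batch_tokens):
        # Called once per batch: add this batch's summed NLL and token count.
        self._total_nll += batch_nll_sum
        self._total_tokens += batch_tokens

    def compute(self):
        # Called at the end of each epoch: fail if nothing was accumulated.
        if self._total_tokens == 0:
            raise RuntimeError("at least one token is needed before compute()")
        return math.exp(self._total_nll / self._total_tokens)


metric = RunningPerplexity()
metric.update(6.0, 2)
metric.update(8.0, 8)
epoch_ppl = metric.compute()

metric.reset()  # a new epoch starts from an empty state
try:
    metric.compute()
    failed = False
except RuntimeError:
    failed = True
```

Keeping only two running totals means the state stays constant in size regardless of how many batches are seen, and calling compute on an empty state is reported as an error rather than returning a misleading value.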