generate.utils

olmo_core.generate.utils.selective_log_softmax(logits, index)

Compute log-softmax probabilities for the selected tokens only.

Note

Under torch.compile(), the log softmax and the subsequent gather can be fused so that the full (..., vocab_size) log-softmax tensor is never materialized. This can save significant memory compared with computing the full log softmax and then indexing into it.
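The memory saving rests on a simple identity: for one position, log_softmax(x)[i] = x[i] - logsumexp(x), so a selected log-probability needs only the selected logit and one logsumexp value per position. The sketch below compares that formulation with the naive "full log softmax, then gather" approach in eager PyTorch. Both function names are hypothetical and this is not necessarily how olmo_core implements selective_log_softmax. Note that in eager mode torch.logsumexp may still allocate transient vocab-sized temporaries internally; avoiding those is the kind of thing a compiler fusion can do.

```python
import torch

def naive_selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: builds the full (..., vocab_size) log-softmax tensor,
    # then gathers one entry per position along the last dimension.
    return torch.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)

def logsumexp_selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper using log_softmax(x)[i] = x[i] - logsumexp(x):
    # returns only the selected logits minus one logsumexp value per position,
    # without producing a full log-softmax tensor.
    selected = logits.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)

torch.manual_seed(0)
logits = torch.randn(3, 7, 50)            # shape (..., vocab_size)
index = torch.randint(0, 50, (3, 7))      # shape (...), values in [0, vocab_size)

naive_log_probs = naive_selective_log_softmax(logits, index)
lse_log_probs = logsumexp_selective_log_softmax(logits, index)
print(torch.allclose(naive_log_probs, lse_log_probs, atol=1e-5))  # True
```

The two results agree up to floating-point rounding; the difference is only in which intermediate tensors are created along the way.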

Parameters:
  • logits (Tensor) – The logits tensor of shape (..., vocab_size).

  • index (Tensor) – Token indices into the last (vocab_size) dimension of logits, of shape (...).

Return type:

Tensor

Returns:

The log-probabilities of the selected tokens, with the same shape (...) as index.
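As a usage sketch of the shape contract, consider scoring sampled tokens: logits of shape (batch, seq_len, vocab_size) and token ids of shape (batch, seq_len) give per-token log-probabilities of shape (batch, seq_len). So that the snippet runs without olmo_core installed, it uses a hypothetical eager stand-in, selective_log_softmax_ref; with olmo_core available you would pass the same arguments to olmo_core.generate.utils.selective_log_softmax instead.

```python
import torch

def selective_log_softmax_ref(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical eager stand-in with the same inputs and output shape as
    # selective_log_softmax; it materializes the full log softmax first.
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs.gather(-1, index.unsqueeze(-1)).squeeze(-1)

batch, seq_len, vocab_size = 2, 4, 10
torch.manual_seed(0)
logits = torch.randn(batch, seq_len, vocab_size)
tokens = torch.randint(0, vocab_size, (batch, seq_len))

# One log-probability per (batch, position), matching the shape of `tokens`.
token_logprobs = selective_log_softmax_ref(logits, tokens)
print(token_logprobs.shape)  # torch.Size([2, 4])

# Summing over positions gives each sequence's total log-likelihood.
seq_logprobs = token_logprobs.sum(dim=-1)
```

Every returned value is at most 0, since each is the log of a probability.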