generate.sampling

olmo_core.generate.sampling.greedy_selection(logits)

Deterministically select the next token as the one with the highest logit.

Parameters:

logits (Tensor) – Logits tensor of shape (..., vocab_size).

Return type:

Tensor

Returns:

Selected token indices of shape (...).
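As a rough illustration of "highest logit" selection, the sketch below takes torch.argmax over the last (vocabulary) dimension. `greedy_selection_sketch` is a hypothetical name, and this is not necessarily how olmo_core implements `greedy_selection`.

```python
import torch


def greedy_selection_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Pick the index of the largest logit along the vocabulary (last) dimension.
    # The vocab dimension is reduced away: (..., vocab_size) -> (...).
    return torch.argmax(logits, dim=-1)


logits = torch.tensor([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
print(greedy_selection_sketch(logits))  # tensor([1, 0])
```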

olmo_core.generate.sampling.top_k_filtering(logits, top_k)

Filter logits to keep only the top k tokens.

Parameters:
  • logits (Tensor) – Logits tensor of shape (..., vocab_size).

  • top_k (int) – Number of top tokens to keep.

Return type:

Tensor

Returns:

Filtered logits, with -inf for tokens outside the top k.
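A minimal sketch of top-k filtering, assuming top_k >= 1: the k-th largest logit in each row serves as a cutoff, and everything below it is replaced with -inf. `top_k_filtering_sketch` is a hypothetical name; the library's actual handling of ties or out-of-range top_k values may differ.

```python
import torch


def top_k_filtering_sketch(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # Assumes top_k >= 1; clamp so torch.topk never asks for more than vocab_size.
    top_k = min(top_k, logits.size(-1))
    # The k-th largest logit in each row acts as the cutoff, shape (..., 1).
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1:]
    # Mask everything strictly below the cutoff; values tied with it are kept.
    return logits.masked_fill(logits < kth_value, float("-inf"))


print(top_k_filtering_sketch(torch.tensor([1.0, 3.0, 2.0, 0.5]), top_k=2).tolist())
# [-inf, 3.0, 2.0, -inf]
```

Because the comparison is strict, logits tied with the k-th value all survive, so this sketch can keep slightly more than k tokens when ties occur.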

olmo_core.generate.sampling.top_p_filtering(logits, top_p)

Filter logits using nucleus (top-p) sampling.

Parameters:
  • logits (Tensor) – Logits tensor of shape (..., vocab_size).

  • top_p (float) – Cumulative probability threshold for nucleus sampling.

Return type:

Tensor

Returns:

Filtered logits with -inf for tokens outside the nucleus.
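One common way to build the nucleus is sketched below: sort the tokens by probability, keep the smallest high-probability prefix whose cumulative probability exceeds top_p, and map the resulting mask back to vocabulary order. `top_p_filtering_sketch` is a hypothetical name, and this is an illustration of the technique rather than olmo_core's implementation.

```python
import torch


def top_p_filtering_sketch(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort each row from most to least likely, remembering original positions.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    # Probability mass of all strictly higher-ranked tokens ("exclusive" cumsum).
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    # Drop a token once the tokens ranked above it already exceed top_p.
    # The top-ranked token has mass_before == 0, so it is never dropped.
    sorted_remove = mass_before > top_p
    # Map the mask from sorted order back to vocabulary order.
    remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_idx, sorted_remove)
    return logits.masked_fill(remove, float("-inf"))


logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05]))
out = top_p_filtering_sketch(logits, top_p=0.7)
print(torch.isinf(out).tolist())  # [False, False, True, True]
```

Here the first two tokens (0.5 + 0.3 = 0.8) are the smallest set whose mass exceeds 0.7, so the remaining two are masked.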

olmo_core.generate.sampling.select_next_token(logits, do_sample=True, temperature=0.0, top_k=-1, top_p=1.0, dtype=torch.float32)

Select the next token from the logits, either greedily or by sampling with temperature scaling and optional top-k and top-p filtering.

Parameters:
  • logits (Tensor) – Logits tensor of shape (..., vocab_size).

  • do_sample (bool, default: True) – Whether to sample from the distribution. If False, uses greedy selection. If True, applies temperature scaling and optional top-k and top-p filtering.

  • temperature (float, default: 0.0) – Temperature used to scale the logits before sampling. Values < 1.0 make the distribution sharper (more deterministic); values > 1.0 make it flatter (more random). A value of 0.0 is equivalent to greedy selection.

  • top_k (int, default: -1) – Only consider the k tokens with the highest probabilities. -1 means no filtering.

  • top_p (float, default: 1.0) – Only consider the smallest set of tokens whose cumulative probability exceeds this threshold (nucleus sampling). 1.0 means no filtering.

  • dtype (dtype, default: torch.float32) – The dtype used for the computation: the input logits are cast to dtype before the operation is performed. This is useful for preventing data type overflows with low-precision inputs.

Return type:

Tensor

Returns:

Sampled token indices of shape (...).
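The following is a minimal sketch of how the documented parameters might fit together. It assumes that temperature scaling means dividing the logits by temperature, that top-k is applied before top-p (a common convention, not confirmed by this page), and that sampling is done with torch.multinomial. All function names here (`_top_k_filter`, `_top_p_filter`, `select_next_token_sketch`) are hypothetical and are not olmo_core's API.

```python
import torch


def _top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # Keep logits >= the k-th largest value per row; mask the rest with -inf.
    top_k = min(top_k, logits.size(-1))
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_value, float("-inf"))


def _top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Keep the smallest high-probability prefix whose mass exceeds top_p.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    sorted_remove = mass_before > top_p
    remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_idx, sorted_remove)
    return logits.masked_fill(remove, float("-inf"))


def select_next_token_sketch(
    logits: torch.Tensor,
    do_sample: bool = True,
    temperature: float = 0.0,
    top_k: int = -1,
    top_p: float = 1.0,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Greedy path: sampling disabled, or temperature 0.0 (documented as greedy).
    if not do_sample or temperature == 0.0:
        return torch.argmax(logits, dim=-1)
    # Cast up front so scaling and softmax run in `dtype` (e.g. float32).
    logits = logits.to(dtype) / temperature
    if top_k > 0:  # -1 (or any value <= 0) disables top-k in this sketch
        logits = _top_k_filter(logits, top_k)
    if top_p < 1.0:  # 1.0 disables top-p
        logits = _top_p_filter(logits, top_p)
    probs = torch.softmax(logits, dim=-1)
    # torch.multinomial accepts only 1-D or 2-D input, so flatten leading dims.
    flat = probs.reshape(-1, probs.size(-1))
    tokens = torch.multinomial(flat, num_samples=1).squeeze(-1)
    return tokens.reshape(probs.shape[:-1])


torch.manual_seed(0)
logits = torch.randn(2, 3, 10)  # (batch, positions, vocab_size)
greedy = select_next_token_sketch(logits)
sampled = select_next_token_sketch(logits, temperature=0.8, top_k=5, top_p=0.9)
print(greedy.shape, sampled.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```

With the signature's defaults (do_sample=True, temperature=0.0), the sketch takes the greedy path, consistent with the temperature description above; a nonzero temperature has to be passed for it to actually sample.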