generate.sampling
- olmo_core.generate.sampling.greedy_selection(logits)
Deterministically select the next token as the one with the highest logit.
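As a rough illustration of greedy selection, the following plain-Python sketch works on a list of floats rather than a tensor. The helper name `greedy_select` is hypothetical and is not part of `olmo_core`; the library function operates on tensors and may differ in details such as tie handling.

```python
def greedy_select(logits):
    """Return the index of the largest logit (the first one on ties)."""
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [1.2, -0.3, 4.5, 0.7]
print(greedy_select(logits))  # 2
```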
- olmo_core.generate.sampling.top_k_filtering(logits, top_k)
Filter logits to keep only the top k tokens.
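One common way to implement top-k filtering is to mask every logit below the k-th largest value with negative infinity, so that those tokens get zero probability after a softmax. The sketch below shows that idea in plain Python. The name `top_k_filter`, the use of `-inf` as the mask value, and treating a non-positive `top_k` as "no filtering" are assumptions of this sketch, not confirmed details of the library function.

```python
import math


def top_k_filter(logits, top_k):
    """Keep the top_k largest logits and mask the rest with -inf."""
    # Assumption: a non-positive or oversized top_k disables filtering.
    if top_k <= 0 or top_k >= len(logits):
        return list(logits)
    # The k-th largest value acts as the cutoff.
    threshold = sorted(logits, reverse=True)[top_k - 1]
    # Values tied with the cutoff are all kept, so ties can keep more than k.
    return [x if x >= threshold else -math.inf for x in logits]


print(top_k_filter([1.0, 3.0, 2.0, 0.5], top_k=2))  # [-inf, 3.0, 2.0, -inf]
```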
- olmo_core.generate.sampling.top_p_filtering(logits, top_p)
Filter logits using nucleus (top-p) sampling.
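Nucleus filtering can be sketched as follows: convert the logits to probabilities, walk the tokens from most to least probable, and stop once the accumulated probability reaches `top_p`. Everything outside that set is masked. The name `top_p_filter`, the `-inf` mask, and the `>=` boundary convention are assumptions of this plain-Python sketch; implementations differ on whether the token that crosses the threshold is kept.

```python
import math


def top_p_filter(logits, top_p):
    """Keep the smallest set of most-probable tokens whose mass reaches top_p."""
    # Numerically stable softmax.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Visit tokens from most to least probable, accumulating probability mass.
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep = set()
    cumulative = 0.0
    for i in order:
        keep.add(i)  # The token that crosses the threshold is kept here.
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    return [x if i in keep else -math.inf for i, x in enumerate(logits)]


# Logits chosen so the probabilities are roughly 0.5, 0.3, 0.15, 0.05.
logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
filtered = top_p_filter(logits, top_p=0.7)
print([i for i, x in enumerate(filtered) if x != -math.inf])  # [0, 1]
```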
- olmo_core.generate.sampling.select_next_token(logits, do_sample=True, temperature=0.0, top_k=-1, top_p=1.0, dtype=torch.float32)
Sample from the logits using temperature scaling with optional top-k and top-p filtering.
- Parameters:
logits (Tensor) – Logits tensor of shape (..., vocab_size).
do_sample (bool, default: True) – Whether to sample from the distribution. If False, uses greedy selection. If True, applies temperature scaling and optional top-k and top-p filtering.
temperature (float, default: 0.0) – Temperature for scaling. Higher values increase randomness. Values < 1.0 make the distribution sharper (more deterministic). Values > 1.0 make the distribution flatter (more random). Value = 0.0 is equivalent to greedy selection.
top_k (int, default: -1) – Only consider the top k tokens with highest probabilities. -1 means no filtering.
top_p (float, default: 1.0) – Only consider the smallest set of tokens whose cumulative probability exceeds this threshold (nucleus sampling). 1.0 means no filtering.
dtype (dtype, default: torch.float32) – The dtype of the output tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows.
- Return type:
Tensor
- Returns:
Sampled token indices of shape (...).
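To make the interaction of the parameters concrete, here is a minimal plain-Python sketch of the documented behavior for a single vector of logits. It is not the library's implementation. The function name `select_next_token_sketch`, the `rng` argument, and the order of operations (temperature, then top-k, then top-p on the renormalized probabilities) are assumptions; the reference above only states that temperature scaling is combined with optional top-k and top-p filtering, and that `do_sample=False` or `temperature=0.0` falls back to greedy selection.

```python
import math
import random


def select_next_token_sketch(logits, do_sample=True, temperature=0.0,
                             top_k=-1, top_p=1.0, rng=random):
    """Pick one token index from a list of logits (hypothetical helper)."""
    # Greedy path: sampling disabled, or temperature 0.0 as documented.
    if not do_sample or temperature == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])

    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Candidate indices from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most probable candidates (-1 disables it).
    if top_k > 0:
        order = order[:top_k]

    # Top-p: keep the smallest prefix whose renormalized mass reaches top_p.
    if top_p < 1.0:
        mass = sum(probs[i] for i in order)
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i] / mass
            if cumulative >= top_p:
                break
        order = kept

    # random.choices accepts unnormalized weights.
    weights = [probs[i] for i in order]
    return rng.choices(order, weights=weights, k=1)[0]


logits = [0.1, 2.0, -1.0]
print(select_next_token_sketch(logits, do_sample=False))  # 1
print(select_next_token_sketch(logits))                   # 1 (temperature=0.0)
print(select_next_token_sketch(logits, temperature=0.8, top_k=2,
                               rng=random.Random(0)))     # 0 or 1
```

Note that with the documented defaults (`do_sample=True`, `temperature=0.0`) the result is still deterministic, because a temperature of 0.0 is described as equivalent to greedy selection. Pass a positive temperature to get stochastic output.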