def greedy_selection(logits: torch.Tensor) -> torch.Tensor:
    """
    Pick the highest-scoring token for every position (no randomness).

    :param logits: Logits tensor of shape ``(..., vocab_size)``.
    :returns: Indices of the arg-max along the last dimension, shape ``(...)``.
    """
    return torch.argmax(logits, dim=-1)
def top_k_filtering(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """
    Mask out every logit that is not among the ``top_k`` largest.

    :param logits: Logits tensor of shape ``(..., vocab_size)``.
    :param top_k: How many of the highest-scoring tokens to keep. Values
        ``<= 0`` disable filtering and return ``logits`` unchanged.
    :returns: Logits where every entry strictly below the k-th largest value
        is replaced by ``-inf`` (entries tied with the k-th value are kept).
    """
    if top_k <= 0:
        return logits

    # Never ask topk for more entries than the vocabulary holds.
    vocab_size = logits.size(-1)
    num_kept = top_k if top_k < vocab_size else vocab_size

    # The smallest of the retained values acts as the cut-off per row.
    top_values = torch.topk(logits, k=num_kept, dim=-1).values
    cutoff = top_values[..., -1, None]

    return logits.masked_fill(logits < cutoff, float("-inf"))
def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """
    Filter logits using nucleus (top-p) sampling.

    Tokens are ranked by probability and kept up to and including the first
    token whose cumulative probability exceeds ``top_p``; everything after it
    is masked. The highest-probability token is always kept.

    :param logits: Logits tensor of shape ``(..., vocab_size)``.
    :param top_p: Cumulative probability threshold for nucleus sampling.
        Values ``<= 0.0`` or ``>= 1.0`` disable filtering and return
        ``logits`` unchanged.
    :returns: Filtered logits with -inf for tokens outside the nucleus.
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # A token (in sorted order) is removed when the cumulative mass of the
    # tokens *before* it already exceeds top_p. Building the shifted mask from
    # the raw comparison keeps the token that crosses the threshold and leaves
    # position 0 (the most likely token) unconditionally kept. Forcing
    # position 0 to False before shifting would wrongly keep position 1 too.
    exceeds_threshold = cumulative_probs > top_p
    sorted_indices_to_remove = torch.zeros_like(exceeds_threshold)
    sorted_indices_to_remove[..., 1:] = exceeds_threshold[..., :-1]

    # Map the mask from sorted order back to the original vocabulary order.
    indices_to_remove = torch.zeros_like(exceeds_threshold).scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )

    # Set filtered tokens to -inf
    return logits.masked_fill(indices_to_remove, float("-inf"))
def select_next_token(
    logits: torch.Tensor,
    do_sample: bool = True,
    temperature: float = 0.0,
    top_k: int = -1,
    top_p: float = 1.0,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Sample from the logits using temperature scaling with optional top-k and top-p filtering.

    :param logits: Logits tensor of shape ``(..., vocab_size)``.
    :param do_sample: Whether to sample from the distribution. If False, uses
        greedy selection. If True, applies temperature scaling and optional
        top-k and top-p filtering.
    :param temperature: Temperature for scaling. Higher values increase
        randomness. Values < 1.0 make the distribution sharper (more
        deterministic). Values > 1.0 make the distribution flatter (more
        random). Value = 0.0 is equivalent to greedy selection.
    :param top_k: Only consider the top k tokens with highest probabilities.
        -1 means no filtering.
    :param top_p: Only consider the smallest set of tokens whose cumulative
        probability exceeds this threshold (nucleus sampling). 1.0 means no
        filtering.
    :param dtype: The dtype in which the softmax probabilities are computed.
        The scaled logits are cast to this dtype before the softmax, which is
        useful for preventing data type overflows. The returned indices are
        integer indices as produced by ``torch.multinomial``.
    :returns: Sampled token indices of shape ``(...)``.
    :raises AssertionError: If ``logits`` contains NaN values on the sampling
        path (the greedy path does not check for NaNs).
    """
    if not do_sample or temperature == 0:
        return greedy_selection(logits)

    nan_mask = torch.isnan(logits)
    if nan_mask.any():
        # Diagnostics are built only on the failure path so the common case
        # does not pay for a per-row scan.
        num_nans = nan_mask.sum().item()
        total_elements = logits.numel()
        nan_percentage = (num_nans / total_elements) * 100
        # Indices along the first dimension that contain at least one NaN.
        if logits.dim() > 0:
            batch_nan_info = (
                nan_mask.reshape(nan_mask.shape[0], -1).any(dim=1).nonzero().flatten().tolist()
            )
        else:
            batch_nan_info = []
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError(
            f"NaN values detected in logits: {num_nans}/{total_elements} ({nan_percentage:.2f}%) "
            f"NaN values in tensor of shape {logits.shape}"
            + (f" in batch elements: {', '.join(map(str, batch_nan_info))}" if batch_nan_info else "")
        )

    scaled_logits = logits / temperature
    if top_k != -1:
        scaled_logits = top_k_filtering(scaled_logits, top_k)
    if top_p != 1.0:
        scaled_logits = top_p_filtering(scaled_logits, top_p)

    probs = torch.softmax(scaled_logits, dim=-1, dtype=dtype)

    # torch.multinomial only accepts 1-D or 2-D input; keep that path as is.
    if probs.dim() <= 2:
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

    # For higher-rank input, flatten the leading dimensions into one batch
    # dimension, sample, then restore the leading shape.
    flat_probs = probs.reshape(-1, probs.size(-1))
    return torch.multinomial(flat_probs, num_samples=1).reshape(probs.shape[:-1])