[docs]@dataclassclassGenerationConfig(Config):"""Configuration for text generation."""pad_token_id:int"""Padding token ID."""eos_token_id:int"""End of sequence token ID."""max_length:Optional[int]=None"""Maximum length of input + newly generated tokens."""max_new_tokens:Optional[int]=None"""Maximum number of new tokens to generate. If provided, this takes precedence over max_length."""do_sample:bool=True"""Whether to use sampling for generation. If False, greedy decoding is used. This overrides temperature, top_k, and top_p."""temperature:float=0.0"""Temperature for sampling. If 0, this is equivalent to greedy selection."""top_k:int=-1"""Top-k sampling. Only consider the top k tokens with the highest probabilities. -1 means no filtering."""top_p:float=1.0"""Top-p (nucleus) sampling. Only consider the smallest set of tokens whose cumulative probability exceeds this threshold. 1.0 means no filtering."""use_cache:bool=True"""Whether to use an inference cache (e.g. a kv-cache) for generation."""stop_token_ids:Optional[List[int]]=None"""Tokens to stop generation at. If provided, the generation will stop when any of these tokens are generated."""def__post_init__(self):self.validate()
[docs]defvalidate(self):"""Validate the generation configuration."""ifself.pad_token_id<0:raiseValueError(f"pad_token_id must be non-negative, got {self.pad_token_id}")ifself.eos_token_id<0:raiseValueError(f"eos_token_id must be non-negative, got {self.eos_token_id}")ifself.pad_token_id==self.eos_token_id:raiseValueError(f"pad_token_id and eos_token_id must be different, got {self.pad_token_id} and {self.eos_token_id}")ifself.max_lengthisnotNoneandself.max_length<=0:raiseValueError(f"max_length must be positive, got {self.max_length}")ifself.max_new_tokensisnotNoneandself.max_new_tokens<=0:raiseValueError(f"max_new_tokens must be positive, got {self.max_new_tokens}")ifself.temperature<0:raiseValueError(f"temperature must be non-negative, got {self.temperature}")ifself.top_k<=0andself.top_k!=-1:raiseValueError(f"top_k must be positive or -1, got {self.top_k}")ifself.top_p<=0.0orself.top_p>1.0:raiseValueError(f"top_p must be in (0, 1], got {self.top_p}")