generate.generation_module
- class olmo_core.generate.generation_module.GenerationConfig(pad_token_id, eos_token_id, max_length=None, max_new_tokens=None, do_sample=True, temperature=0.0, top_k=-1, top_p=1.0, use_cache=True, stop_token_ids=None)
Bases: Config
Configuration for text generation.
- max_new_tokens: Optional[int] = None
Maximum number of new tokens to generate. If provided, this takes precedence over max_length.
- do_sample: bool = True
Whether to use sampling for generation. If False, greedy decoding is used, which overrides temperature, top_k, and top_p.
- top_k: int = -1
Top-k sampling. Only consider the k tokens with the highest probabilities. -1 means no filtering.
- top_p: float = 1.0
Top-p (nucleus) sampling. Only consider the smallest set of tokens whose cumulative probability exceeds this threshold. 1.0 means no filtering.
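The sketch below shows one common way the sampling options above can interact: greedy decoding when do_sample is False, otherwise temperature scaling followed by top-k and then top-p filtering. It is plain Python for illustration and is not OLMo-core's implementation; the helper names (softmax, filter_logits, sample_next_token), the filter order, and treating temperature=0.0 as greedy are assumptions of this sketch.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax; -inf entries get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def filter_logits(logits: List[float], top_k: int = -1, top_p: float = 1.0) -> List[float]:
    """Hypothetical helper: set logits excluded by top-k, then top-p, to -inf."""
    # Token indices ordered from highest to lowest logit.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    filtered = list(logits)
    if top_k > 0:
        # Top-k: keep the k highest-scoring tokens; -1 disables this filter.
        for i in order[top_k:]:
            filtered[i] = -math.inf
    if top_p < 1.0:
        # Top-p: walk tokens from most to least likely (renormalized after top-k)
        # and keep the smallest prefix whose cumulative probability reaches top_p.
        probs = softmax(filtered)
        cumulative = 0.0
        for rank, i in enumerate(order):
            cumulative += probs[i]
            if cumulative >= top_p:
                for j in order[rank + 1:]:
                    filtered[j] = -math.inf
                break
    return filtered


def sample_next_token(
    logits: List[float],
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = -1,
    top_p: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: pick the next token id from one row of logits."""
    if not do_sample or temperature == 0.0:
        # Greedy decoding ignores top_k and top_p. Treating temperature=0.0
        # as greedy is an assumption of this sketch.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    scaled = [x / temperature for x in logits]
    probs = softmax(filter_logits(scaled, top_k, top_p))
    rng = rng or random.Random()
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Applying top-k before top-p mirrors a widespread convention; applying them in the other order can yield a slightly different candidate set.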
- class olmo_core.generate.generation_module.GenerationModule(*args, **kwargs)
Bases: Stateful
- property dp_process_group: ProcessGroup | None
Should return the data parallel process group if it’s anything other than the default process group.
- class olmo_core.generate.generation_module.TransformerGenerationModule(model, generation_config, compile_model=False, float8_config=None, dp_config=None, device=None, state_dict_load_opts=None, state_dict_save_opts=None, load_key_mapping=None)
Bases: GenerationModule
Module for autoregressive text generation with transformer models.
- property dp_process_group: ProcessGroup | None
Should return the data parallel process group if it’s anything other than the default process group.
- state_dict()
Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict().
Warning
Because of the in-place nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.
- Returns:
The object’s state dict.
- Return type:
Dict
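To make the warning concrete, here is a toy, pure-Python sketch of an in-place restore: the loader first calls state_dict() to obtain the live destination buffers, writes checkpoint values into them, and then passes the result to load_state_dict(). It assumes, as a simplification, that torch.distributed.checkpoint.load follows roughly this sequence for Stateful objects; ToyModule and toy_in_place_load are hypothetical names, and lists stand in for tensors.

```python
from typing import Dict, List


class ToyModule:
    """Hypothetical stand-in for a Stateful object whose state lives in mutable buffers."""

    def __init__(self) -> None:
        self.weights: List[float] = [0.0, 0.0, 0.0]

    def state_dict(self) -> Dict[str, List[float]]:
        # Return references to the live buffers, not copies, so a loader can fill them in place.
        return {"weights": self.weights}

    def load_state_dict(self, state_dict: Dict[str, List[float]]) -> None:
        # Copy into the existing buffer; effectively a no-op if the loader already wrote in place.
        self.weights[:] = state_dict["weights"]


def toy_in_place_load(obj: ToyModule, saved: Dict[str, List[float]]) -> None:
    """Hypothetical loader mimicking an in-place checkpoint restore."""
    # 1. Ask the object for its state dict to get the destination buffers.
    target = obj.state_dict()
    # 2. Write checkpoint values directly into those buffers.
    for key, value in saved.items():
        target[key][:] = value
    # 3. Hand the filled state dict back to the object.
    obj.load_state_dict(target)
```

Because step 1 calls state_dict() during loading, an implementation that returned fresh copies would have its restored values written into throwaway objects unless load_state_dict() copies them back.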
- generate_batch(input_ids, *, attention_mask=None, return_logits=False, return_logprobs=False, completions_only=False, log_timing=True, **generation_kwargs)
Generate text with autoregressive decoding.
- Parameters:
input_ids (Tensor) – Input token IDs of shape (batch_size, seq_len).
attention_mask (Optional[Tensor], default: None) – Optional attention mask of shape (batch_size, seq_len). This should be a left-padding mask, not an arbitrary attention mask. If not provided, the model assumes there are no left-padding tokens.
return_logits (bool, default: False) – If True, return logits along with the generated tokens.
return_logprobs (bool, default: False) – If True, return log probabilities of the generated tokens along with the generated tokens. This is notably more memory-efficient than return_logits.
completions_only (bool, default: False) – If True, return only the completions, not the entire sequence.
generation_kwargs – Generation configuration overrides.
- Return type:
- Returns:
Tuple of (generated_ids, logits, logprobs) where:
- generated_ids: Generated token IDs of shape (batch_size, output_length).
- logits: Full logits if return_logits=True, else None. Shape: (batch_size, output_length, vocab_size).
- logprobs: Log probabilities of generated tokens if return_logprobs=True, else None. Shape: (batch_size, output_length).
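Because attention_mask must describe left padding only, a batch of variable-length prompts is typically padded on the left before generation. The pure-Python sketch below builds such a batch and shows how completions_only relates to slicing off the prompt. The mask convention (1 for real tokens, 0 for padding), the helper names left_pad and completions, and the assumption that completions start at index seq_len are illustrative assumptions not confirmed by this reference; with PyTorch you would wrap the resulting lists in torch.tensor before passing them in.

```python
from typing import List, Sequence, Tuple


def left_pad(
    prompts: Sequence[Sequence[int]], pad_token_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: left-pad prompts to a common length and build a padding mask."""
    seq_len = max(len(p) for p in prompts)
    input_ids, attention_mask = [], []
    for p in prompts:
        n_pad = seq_len - len(p)
        # Padding goes on the left so every prompt ends at the same position
        # and new tokens can be appended on the right for the whole batch.
        input_ids.append([pad_token_id] * n_pad + list(p))
        # Assumed convention: 0 marks a left-padding position, 1 a real prompt token.
        attention_mask.append([0] * n_pad + [1] * len(p))
    return input_ids, attention_mask


def completions(generated_ids: Sequence[Sequence[int]], seq_len: int) -> List[List[int]]:
    """Hypothetical helper: drop the padded prompt prefix, keeping only new tokens."""
    return [list(row[seq_len:]) for row in generated_ids]
```

Left padding keeps the last prompt token of every row at index seq_len - 1, so each decoding step writes to the same column for the whole batch.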
- load_checkpoint(checkpoint_dir, work_dir, process_group=None, pre_download=True, load_thread_count=None)
Load a model checkpoint.
- Parameters:
checkpoint_dir (Union[Path, PathLike, str]) – Path to the checkpoint directory.
work_dir (Union[Path, PathLike, str]) – Working directory for caching remote checkpoints.
process_group (Optional[ProcessGroup], default: None) – Process group for distributed loading.
pre_download (bool, default: True) – Whether to pre-download remote checkpoints.
load_thread_count (Optional[int], default: None) – Number of threads to use for loading the checkpoint.
- Raises:
FileNotFoundError – If the checkpoint directory doesn’t exist.
RuntimeError – If checkpoint loading fails.
- classmethod from_checkpoint(checkpoint_dir, *, transformer_config=None, generation_config=None, process_group=None, work_dir=None, pre_download=True, load_thread_count=None, dtype=None, attention_backend=None, **kwargs)
Create a GenerationModule from a checkpoint.
This is a convenience method that combines model initialization and checkpoint loading.
- Parameters:
checkpoint_dir (Union[Path, PathLike, str]) – Path to the checkpoint directory.
transformer_config (Optional[TransformerConfig], default: None) – Configuration for the transformer model. If not provided, it is loaded from the checkpoint’s config.json file.
generation_config (Optional[GenerationConfig], default: None) – Configuration for generation. If not provided, the default GenerationConfig is used.
process_group (Optional[ProcessGroup], default: None) – Process group for distributed checkpoint loading.
pre_download (bool, default: True) – Whether to pre-download remote checkpoints.
load_thread_count (Optional[int], default: None) – Number of threads to use for loading the checkpoint.
dtype (Optional[DType], default: None) – If provided, build the model with this dtype.
attention_backend (Optional[AttentionBackendName], default: None) – If provided, override the config to use this attention backend.
**kwargs – Additional keyword arguments passed to the TransformerGenerationModule constructor.
- Return type:
- Returns:
A TransformerGenerationModule instance with the checkpoint loaded.
- Raises:
FileNotFoundError – If the checkpoint directory doesn’t exist.
OLMoConfigurationError – If the transformer config cannot be determined.
RuntimeError – If checkpoint loading fails.
- class olmo_core.generate.generation_module.TransformerGenerationModuleConfig(generation_config, compile_model=False, float8_config=None, dp_config=None, state_dict_load_opts=None, load_key_mapping=None, dtype=None)
Bases: Config
A configuration class for building TransformerGenerationModule instances.
Warning
This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature, please read the CHANGELOG before upgrading your version of this library.
- build(checkpoint_dir, transformer_config=None, device=None, process_group=None, work_dir=None, pre_download=True, load_thread_count=None)
Build the corresponding TransformerGenerationModule.
- Parameters:
checkpoint_dir (Union[Path, PathLike, str, List[Union[Path, PathLike, str]]]) – Checkpoint directory to load from.
transformer_config (Optional[TransformerConfig], default: None) – The TransformerConfig to use for generation.
device (Optional[device], default: None) – The device to use for generation.
process_group (Optional[ProcessGroup], default: None) – The process group for distributed loading.
work_dir (Union[Path, PathLike, str, None], default: None) – Working directory for temporary files during loading.
pre_download (bool, default: True) – Whether to pre-download remote checkpoints.
load_thread_count (Optional[int], default: None) – Number of threads to use for loading.
- Return type: