generate.generation_module

class olmo_core.generate.generation_module.GenerationConfig(pad_token_id, eos_token_id, max_length=None, max_new_tokens=None, do_sample=True, temperature=0.0, top_k=-1, top_p=1.0, use_cache=True, stop_token_ids=None)[source]

Bases: Config

Configuration for text generation.

pad_token_id: int

Padding token ID.

eos_token_id: int

End of sequence token ID.

max_length: Optional[int] = None

Maximum length of input + newly generated tokens.

max_new_tokens: Optional[int] = None

Maximum number of new tokens to generate. If provided, this takes precedence over max_length.
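
The precedence rule above can be sketched as a small helper. This is only an illustration: `resolve_max_length` and its `prompt_len` argument are hypothetical names, not part of the library, and the library's actual resolution logic is not shown in this reference.

```python
from typing import Optional


def resolve_max_length(
    prompt_len: int,
    max_length: Optional[int] = None,
    max_new_tokens: Optional[int] = None,
) -> Optional[int]:
    """Return a cap on total sequence length (prompt + generated), or None.

    Hypothetical helper: max_new_tokens, when set, wins over max_length.
    """
    if max_new_tokens is not None:
        # The cap is expressed relative to the prompt length.
        return prompt_len + max_new_tokens
    # Fall back to the absolute cap (which may itself be None).
    return max_length


print(resolve_max_length(10, max_length=64, max_new_tokens=5))  # 15
```

With both limits set, a 10-token prompt is capped at 15 total tokens rather than 64.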

do_sample: bool = True

Whether to use sampling for generation. If False, greedy decoding is used. This overrides temperature, top_k, and top_p.

temperature: float = 0.0

Temperature for sampling. If 0, this is equivalent to greedy selection.
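
To illustrate why a temperature of 0 amounts to greedy selection, here is a minimal pure-Python sketch of temperature sampling over a single logit vector. The function name `sample_with_temperature` is hypothetical, and the sketch assumes logits are divided by the temperature before the softmax, with 0 handled as a special case; the library's own implementation may differ.

```python
import math
import random
from typing import List


def sample_with_temperature(
    logits: List[float], temperature: float, rng: random.Random
) -> int:
    """Pick a token index from logits (hypothetical helper)."""
    if temperature == 0.0:
        # Dividing by zero is undefined, so treat 0 as greedy (argmax).
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    # Subtract the max before exponentiating for numerical stability.
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]
    # random.choices normalizes the weights internally.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
print(sample_with_temperature([1.0, 3.0, 2.0], 0.0, rng))  # 1
```

Lower temperatures concentrate probability mass on the highest logit, and higher temperatures flatten the distribution.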

top_k: int = -1

Top-k sampling. Only consider the k tokens with the highest probabilities. -1 means no filtering.

top_p: float = 1.0

Top-p (nucleus) sampling. Only consider the smallest set of tokens whose cumulative probability exceeds this threshold. 1.0 means no filtering.
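
The two filters can be sketched together on a plain list of probabilities. This is an illustration under stated assumptions, not the library's code: `filter_probs` is a hypothetical name, top-k is applied before top-p, and the cumulative sum for top-p is taken over the unnormalized probabilities that survive top-k.

```python
from typing import List


def filter_probs(
    probs: List[float], top_k: int = -1, top_p: float = 1.0
) -> List[float]:
    """Zero out filtered tokens and renormalize (hypothetical helper)."""
    # Token indices sorted from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # top_k == -1 means no top-k filtering.
    keep = order if top_k == -1 else order[:top_k]
    if top_p < 1.0:
        nucleus, cumulative = [], 0.0
        for i in keep:
            nucleus.append(i)
            cumulative += probs[i]
            # Stop once the kept mass exceeds the threshold.
            if cumulative > top_p:
                break
        keep = nucleus
    kept = set(keep)
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]


print(filter_probs([0.5, 0.3, 0.15, 0.05], top_p=0.7))
```

In the example, the two most probable tokens carry a cumulative mass of about 0.8, which exceeds 0.7, so only those two remain and are renormalized.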

use_cache: bool = True

Whether to use an inference cache (e.g. a kv-cache) for generation.

stop_token_ids: Optional[List[int]] = None

Tokens to stop generation at. If provided, the generation will stop when any of these tokens are generated.
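
A minimal sketch of the stopping rule for a single sequence follows. The helper `truncate_at_stop` is hypothetical, and two of its behaviors are assumptions rather than documented facts: that eos_token_id also ends generation, and that the stop token itself is kept in the output.

```python
from typing import List, Optional


def truncate_at_stop(
    tokens: List[int],
    eos_token_id: int,
    stop_token_ids: Optional[List[int]] = None,
) -> List[int]:
    """Cut a token stream at the first stop token (hypothetical helper)."""
    # Assumption: EOS always stops generation, in addition to stop_token_ids.
    stops = {eos_token_id, *(stop_token_ids or [])}
    for i, tok in enumerate(tokens):
        if tok in stops:
            # Assumption: the stop token is included in the result.
            return tokens[: i + 1]
    # No stop token seen: the stream ran to its length limit.
    return tokens


print(truncate_at_stop([5, 6, 7, 2, 9], eos_token_id=2, stop_token_ids=[7]))  # [5, 6, 7]
```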

validate()[source]

Validate the generation configuration.

class olmo_core.generate.generation_module.GenerationModule(*args, **kwargs)[source]

Bases: Stateful

property dp_process_group: ProcessGroup | None

Should return the data parallel process group if it’s anything other than the default process group.

abstract load_state_dict(state_dict)[source]

Load a state dict.

Return type:

None

class olmo_core.generate.generation_module.TransformerGenerationModule(model, generation_config, compile_model=False, float8_config=None, dp_config=None, device=None, state_dict_load_opts=None, state_dict_save_opts=None, load_key_mapping=None)[source]

Bases: GenerationModule

Module for autoregressive text generation with transformer models.

property dp_process_group: ProcessGroup | None

Should return the data parallel process group if it’s anything other than the default process group.

state_dict()[source]

Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict().

Warning

Because of the in-place nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.

Returns:

The object's state dict

Return type:

Dict

load_state_dict(state_dict)[source]

Load a state dict.

Return type:

None

generate_batch(input_ids, *, attention_mask=None, return_logits=False, return_logprobs=False, completions_only=False, log_timing=True, **generation_kwargs)[source]

Generate text with autoregressive decoding.

Parameters:
  • input_ids (Tensor) – Input token IDs of shape (batch_size, seq_len).

  • attention_mask (Optional[Tensor], default: None) – Optional attention mask of shape (batch_size, seq_len). This should be a left-padding mask, not an arbitrary attention mask. If not provided, the model will assume there are no left-padding tokens.

  • return_logits (bool, default: False) – If True, return logits along with generated tokens.

  • return_logprobs (bool, default: False) – If True, return log probabilities for the generated tokens along with generated tokens. This is notably more memory efficient than return_logits.

  • completions_only (bool, default: False) – If True, return only the completions, not the entire sequence.

  • generation_kwargs – Generation configuration overrides.
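
The left-padding requirement on attention_mask can be illustrated with plain lists. This sketch assumes the common convention that 1 marks a real token and 0 marks padding; the reference above does not state the convention, and `left_pad` is a hypothetical helper. In practice the two results would be converted to tensors of shape (batch_size, seq_len) before being passed as input_ids and attention_mask.

```python
from typing import List, Tuple


def left_pad(
    prompts: List[List[int]], pad_token_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad prompts to a common length and build the matching mask."""
    width = max(len(p) for p in prompts)
    input_ids, attention_mask = [], []
    for p in prompts:
        n_pad = width - len(p)
        # Padding goes on the left so every prompt ends at the same position.
        input_ids.append([pad_token_id] * n_pad + p)
        # Assumed convention: 0 = padding, 1 = real token.
        attention_mask.append([0] * n_pad + [1] * len(p))
    return input_ids, attention_mask


ids, mask = left_pad([[11, 12, 13], [21]], pad_token_id=0)
print(ids)   # [[11, 12, 13], [0, 0, 21]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

Left padding keeps the last real token of every prompt in the final column, so newly generated tokens are appended at the same position for every row in the batch.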

Return type:

Tuple[Tensor, Optional[Tensor], Optional[Tensor]]

Returns:

Tuple of (generated_ids, logits, logprobs) where:
  • generated_ids: Generated token IDs of shape (batch_size, output_length).

  • logits: Full logits if return_logits=True, else None. Shape: (batch_size, output_length, vocab_size).

  • logprobs: Log probabilities of generated tokens if return_logprobs=True, else None. Shape: (batch_size, output_length).
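
The shapes above show why log probabilities are cheaper to return than logits: the logits carry a full vocab_size axis, while the log probabilities keep one value per generated token. The relationship can be sketched for a single sequence in pure Python. `token_logprobs` is a hypothetical helper that assumes log probabilities are the log-softmax of the logits evaluated at the chosen token; how the library computes them is not shown in this reference.

```python
import math
from typing import List


def token_logprobs(logits: List[List[float]], token_ids: List[int]) -> List[float]:
    """Reduce per-step logits (steps x vocab) to one log-prob per step."""
    out = []
    for step_logits, tok in zip(logits, token_ids):
        # Stable log-sum-exp: log of the softmax normalizer.
        m = max(step_logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in step_logits))
        # log softmax(logits)[tok] = logits[tok] - log Z
        out.append(step_logits[tok] - log_z)
    return out


# Two decoding steps over a 3-token vocabulary.
step_logits = [[0.0, 0.0, 0.0], [2.0, 1.0, 0.0]]
print(token_logprobs(step_logits, [1, 0]))
```

Here a 2 x 3 table of logits collapses to two numbers, and the same reduction applies per batch row in the tensor case.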

load_checkpoint(checkpoint_dir, work_dir, process_group=None, pre_download=True, load_thread_count=None)[source]

Load model checkpoint.

Parameters:
  • checkpoint_dir (Union[Path, PathLike, str]) – Path to checkpoint directory

  • work_dir (Union[Path, PathLike, str]) – Working directory for caching remote checkpoints

  • process_group (Optional[ProcessGroup], default: None) – Process group for distributed loading

  • pre_download (bool, default: True) – Whether to pre-download remote checkpoints

  • load_thread_count (Optional[int], default: None) – Number of threads to use for loading the checkpoint

Raises:

classmethod from_checkpoint(checkpoint_dir, *, transformer_config=None, generation_config=None, process_group=None, work_dir=None, pre_download=True, load_thread_count=None, dtype=None, attention_backend=None, **kwargs)[source]

Create a GenerationModule from a checkpoint.

This is a convenience method that combines model initialization and checkpoint loading.

Parameters:
  • checkpoint_dir (Union[Path, PathLike, str]) – Path to checkpoint directory.

  • transformer_config (Optional[TransformerConfig], default: None) – Configuration for the transformer model. If not provided, will be loaded from the checkpoint’s config.json file.

  • generation_config (Optional[GenerationConfig], default: None) – Configuration for generation. If not provided, a default GenerationConfig is used.

  • process_group (Optional[ProcessGroup], default: None) – Process group for distributed checkpoint loading.

  • pre_download (bool, default: True) – Whether to pre-download remote checkpoints.

  • load_thread_count (Optional[int], default: None) – Number of threads to use for loading the checkpoint.

  • dtype (Optional[DType], default: None) – If provided, build the model with this dtype.

  • attention_backend (Optional[AttentionBackendName], default: None) – If provided, override the config to use this attention backend.

  • **kwargs – Additional keyword arguments passed to the TransformerGenerationModule constructor.

Return type:

TransformerGenerationModule

Returns:

TransformerGenerationModule instance with loaded checkpoint.

Raises:

class olmo_core.generate.generation_module.TransformerGenerationModuleConfig(generation_config, compile_model=False, float8_config=None, dp_config=None, state_dict_load_opts=None, load_key_mapping=None, dtype=None)[source]

Bases: Config

A configuration class for building TransformerGenerationModule instances.

Warning

This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.

dtype: Optional[DType] = None

The dtype to build the model in.

build(checkpoint_dir, transformer_config=None, device=None, process_group=None, work_dir=None, pre_download=True, load_thread_count=None)[source]

Build the corresponding TransformerGenerationModule.

Parameters:

Return type:

TransformerGenerationModule