train.callbacks¶
Trainer Callback implementations.
- class olmo_core.train.callbacks.Callback[source]¶
Bases:
Stateful
Trainer callback base class.
Callbacks can be used to modify and extend the behavior of the trainer loop. This module contains a number of useful Callback implementations, but you can always add your own.
-
priority:
ClassVar[int] = 0¶ Priority of the callback. Determines the order in which callbacks run relative to each other. The higher the priority, the earlier a callback runs.
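The ordering rule can be illustrated with a minimal, self-contained sketch. ToyCallback and the callback names are hypothetical stand-ins rather than olmo_core classes; only the rule itself (higher priority runs earlier) comes from the description above, and the tie-breaking behavior is an assumption.

```python
from dataclasses import dataclass


@dataclass
class ToyCallback:
    """Hypothetical stand-in for a callback; not an olmo_core class."""

    name: str
    priority: int = 0


def run_order(callbacks):
    # Higher priority runs earlier. sorted() is stable, so callbacks with equal
    # priority keep their original relative order (an assumption of this sketch).
    return sorted(callbacks, key=lambda cb: cb.priority, reverse=True)


callbacks = [
    ToyCallback("hf_converter", priority=-1),
    ToyCallback("custom", priority=0),
    ToyCallback("model_merge", priority=2),
    ToyCallback("checkpointer", priority=1),
]
order = [cb.name for cb in run_order(callbacks)]
```

With these priorities, the merge callback would run first and the HF converter last.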
- block_ephemeral_checkpoints()[source]¶
Register this callback as blocking ephemeral checkpoint saves. Ephemeral saves are blocked as long as at least one callback is registered.
- unblock_ephemeral_checkpoints()[source]¶
Unregister this callback from blocking ephemeral checkpoint saves.
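One way to picture the register/unregister semantics is a small registry in which ephemeral saves stay blocked until the last blocker unregisters. EphemeralSaveGate is a hypothetical helper written for illustration; the trainer's actual bookkeeping may differ.

```python
class EphemeralSaveGate:
    """Hypothetical registry of callbacks that block ephemeral saves."""

    def __init__(self):
        self._blockers = set()

    def block(self, callback_name):
        self._blockers.add(callback_name)

    def unblock(self, callback_name):
        # discard() is a no-op if the callback was never registered.
        self._blockers.discard(callback_name)

    @property
    def blocked(self):
        # Blocked as long as at least one callback is still registered.
        return len(self._blockers) > 0


gate = EphemeralSaveGate()
gate.block("merge")
gate.block("eval")
gate.unblock("merge")
blocked_with_one_left = gate.blocked
gate.unblock("eval")
blocked_with_none_left = gate.blocked
```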
- pre_optim_step()[source]¶
Runs right after the forward-backward passes, right before the optimizer step.
- pre_log_metrics(step, metrics)[source]¶
Called when metrics have been gathered for a given step (possibly a previous step), but right before log_metrics(). This can be used to modify, add, or remove metrics by updating the metrics dict in-place.
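A minimal sketch of an in-place update in pre_log_metrics() follows. To stay self-contained it uses a plain class instead of subclassing Callback, and it treats metric values as floats; "train/perplexity" and "debug/scratch" are hypothetical metric names chosen for the example.

```python
import math


class PerplexityHook:
    """Hypothetical hook holder; a real callback would subclass Callback."""

    def pre_log_metrics(self, step, metrics):
        # Mutate the dict in place so the caller sees the changes.
        metrics.pop("debug/scratch", None)  # remove a metric if present
        if "train/CE loss" in metrics:
            # Add a derived metric next to the existing one.
            metrics["train/perplexity"] = math.exp(metrics["train/CE loss"])


metrics = {"train/CE loss": 0.0, "debug/scratch": 123.0}
PerplexityHook().pre_log_metrics(step=10, metrics=metrics)
```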
- class olmo_core.train.callbacks.CallbackConfig[source]¶
An alternative way to define callbacks when the callback class itself can’t be serialized.
- class olmo_core.train.callbacks.CheckpointerCallback(save_interval=250, ephemeral_save_interval=None, pre_train_checkpoint=None, save_async=None, remove='ephemeral_only', ephemeral_cooldown=None, fixed_steps=None, enabled=True, _latest_checkpoint_step=-1, _latest_checkpoint_path='', _checkpoints=<factory>, _ephemeral_checkpoints=<factory>, _checkpoints_to_remove=<factory>)[source]¶
Bases:
Callback
Manages checkpointing during training, including writing checkpoints at set intervals determined by save_interval and ephemeral_save_interval, as well as removing old checkpoints found in the save folder as determined by the remove setting.
Important
This callback gets added automatically if you don’t explicitly configure it. If you want to override this callback you should subclass it.
-
priority:
ClassVar[int] = 1¶ Priority of the callback. Determines the order in which callbacks run relative to each other. The higher the priority, the earlier a callback runs.
-
save_interval:
Optional[int] = 250¶ The interval, in steps, with which to save permanent checkpoints.
-
ephemeral_save_interval:
Optional[int] = None¶ The interval, in steps, with which to save temporary checkpoints. These checkpoints are removed each time a new checkpoint is saved.
It can be useful to set this to a relatively frequent interval for preemptible jobs.
-
pre_train_checkpoint:
Optional[bool] = None¶ Save a pretrain checkpoint. Defaults to True unless the trainer resumes from a checkpoint.
-
save_async:
Optional[bool] = None¶ Save checkpoints asynchronously. Requires a separate CPU-only backend. Defaults to True if there is one.
-
remove:
CheckpointRemovalStrategy = 'ephemeral_only'¶ The strategy for removing old checkpoints found in the save folder.
- class olmo_core.train.callbacks.CheckpointRemovalStrategy(value)[source]¶
Bases:
StrEnum
An enumeration of the different strategies for removing old checkpoints found in the save folder.
- ephemeral_only = 'ephemeral_only'¶
Only remove checkpoints that were saved at the
CheckpointerCallback.ephemeral_save_interval.
- all_non_permanent = 'all_non_permanent'¶
Remove all non-permanent checkpoints found, including ephemeral checkpoints and also any other checkpoints that were not saved at the
CheckpointerCallback.save_interval.
- never = 'never'¶
Never remove any old checkpoints found in the save folder.
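The three strategies can be compared with a small sketch. It assumes a checkpoint is "permanent" when its step is a multiple of save_interval and "ephemeral" when its step is a multiple of ephemeral_save_interval but not permanent. That classification, and the steps_to_remove helper itself, are illustrative assumptions; the real callback may track what it wrote differently.

```python
def steps_to_remove(found_steps, strategy, save_interval, ephemeral_save_interval=None):
    """Hypothetical helper: which checkpoint steps each strategy would remove."""

    def is_permanent(step):
        return save_interval is not None and step % save_interval == 0

    def is_ephemeral(step):
        return (
            ephemeral_save_interval is not None
            and step % ephemeral_save_interval == 0
            and not is_permanent(step)
        )

    if strategy == "never":
        return []
    if strategy == "ephemeral_only":
        return [s for s in found_steps if is_ephemeral(s)]
    if strategy == "all_non_permanent":
        return [s for s in found_steps if not is_permanent(s)]
    raise ValueError(f"unknown strategy: {strategy!r}")


found = [50, 100, 130, 250, 300, 500]
ephemeral_only = steps_to_remove(found, "ephemeral_only", 250, 50)
all_non_permanent = steps_to_remove(found, "all_non_permanent", 250, 50)
never = steps_to_remove(found, "never", 250, 50)
```

In this example step 130 matches neither interval, so it is left alone by ephemeral_only but removed by all_non_permanent.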
- class olmo_core.train.callbacks.CometCallback(enabled=True, name=None, project=None, workspace=None, tags=None, config=None, cancel_tags=<factory>, cancel_check_interval=None, notifications='none', failure_tag='failed', auto_resume=False, _exp_key=None, _finalized=False)[source]¶
Bases:
Callback
Logs metrics to Comet.ml from rank 0.
Important
Requires the comet_ml package and the environment variable COMET_API_KEY.
Note
This callback logs metrics from every single step to Comet.ml, regardless of the value of Trainer.metrics_collect_interval.
-
cancel_tags:
Optional[List[str]]¶ If you add any of these tags to an experiment on Comet.ml, the run will cancel itself. Defaults to
["cancel", "canceled", "cancelled"].
-
cancel_check_interval:
Optional[int] = None¶ Check for cancel tags every this many steps. Defaults to
olmo_core.train.Trainer.cancel_check_interval.
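The interaction between cancel_tags and cancel_check_interval can be sketched as a pure function. should_cancel is a hypothetical helper, the interval of 25 steps is an arbitrary example value, and fetching the tags from Comet.ml is left out.

```python
DEFAULT_CANCEL_TAGS = ["cancel", "canceled", "cancelled"]


def should_cancel(step, experiment_tags, cancel_tags=None, cancel_check_interval=25):
    """Hypothetical check: cancel if any cancel tag is present on a check step."""
    if cancel_tags is None:
        cancel_tags = DEFAULT_CANCEL_TAGS
    # Only look at the tags every `cancel_check_interval` steps.
    if step % cancel_check_interval != 0:
        return False
    return any(tag in cancel_tags for tag in experiment_tags)


off_interval = should_cancel(26, ["baseline", "cancel"])
on_interval = should_cancel(50, ["baseline", "cancel"])
no_tag = should_cancel(50, ["baseline"])
```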
-
notifications:
CometNotificationSetting = 'none'¶ The notification settings.
- class olmo_core.train.callbacks.CometNotificationSetting(value)[source]¶
Bases:
StrEnum
Defines the notification settings for the Comet.ml callback.
- all = 'all'¶
Send all types of notifications.
- end_only = 'end_only'¶
Only send a notification when the experiment ends (successfully or with a failure).
- failure_only = 'failure_only'¶
Only send a notification when the experiment fails.
- none = 'none'¶
Don’t send any notifications.
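As a rough sketch, the four settings can be read as a filter over event types. The event names "start", "end", and "failure" are assumptions made for this example; the callback's real event set is not listed here.

```python
def should_notify(setting, event):
    """Hypothetical filter; `event` is one of "start", "end", "failure"."""
    if setting == "all":
        return True
    if setting == "end_only":
        # The experiment has ended, either successfully or with a failure.
        return event in ("end", "failure")
    if setting == "failure_only":
        return event == "failure"
    if setting == "none":
        return False
    raise ValueError(f"unknown setting: {setting!r}")


table = {
    setting: [should_notify(setting, e) for e in ("start", "end", "failure")]
    for setting in ("all", "end_only", "failure_only", "none")
}
```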
- class olmo_core.train.callbacks.ConfigSaverCallback(fname='config.json', save_data_paths=None, data_paths_fname=None, _config=None)[source]¶
Bases:
Callback
A callback that writes an arbitrary JSON-serializable config dictionary (config) to every checkpoint directory written during training. It will also set the config to save for other callbacks, including the WandBCallback, CometCallback, and others, if not already set.
Important
The config should be set after initializing the trainer and attaching all other callbacks.
- class olmo_core.train.callbacks.ConsoleLoggerCallback(log_interval=1, metrics_log_interval=None, metrics=<factory>)[source]¶
Bases:
Callback
Logs progress and a subset of metrics to the console.
Important
This callback gets added automatically if you don’t explicitly configure it. If you want to override this callback you should subclass it.
-
metrics_log_interval:
Optional[int] = None¶ How often, in steps, to log metrics to the console. If not set, defaults to
log_interval.
- class olmo_core.train.callbacks.EvaluatorCallback(evaluators=<factory>, eval_interval=1000, fixed_steps=None, eval_on_startup=False, eval_on_finish=False, cancel_after_first_eval=False, eval_duration=<factory>, log_interval=5)[source]¶
Bases:
Callback
Runs in-loop evaluations for a TransformerTrainModule periodically during training.
- class olmo_core.train.callbacks.LMEvaluatorCallbackConfig(eval_dataset, eval_interval=1000, fixed_steps=None, eval_on_startup=False, eval_on_finish=False, cancel_after_first_eval=False, eval_duration=<factory>, log_interval=5, deterministic=True, enabled=True)[source]¶
Bases:
CallbackConfig
- class olmo_core.train.callbacks.DownstreamEvaluatorCallbackConfig(tasks, tokenizer, eval_interval=1000, fixed_steps=None, eval_duration=<factory>, eval_on_startup=False, eval_on_finish=False, cancel_after_first_eval=False, log_interval=5, lazy=False, enabled=True)[source]¶
Bases:
CallbackConfig
- class olmo_core.train.callbacks.GAPMonitorCallback(enabled=True, monitor=None, interval=1, dump_gradients=None, dump_gradients_start_step=0, dump_gradients_end_step=None, dump_gradients_step_interval=1, dump_gradients_save_first_n=None, _handles=None, _local_batch_size_instances=1, _dry_run_complete=False)[source]¶
Bases:
Callback
Gradient, activation, and parameter (GAP) monitoring callback.
This callback logs fine-grained statistics on all gradients, activations, and parameters.
It can also dump raw gradient tensors to disk for offline analysis. Set dump_gradients=True and configure the dump_gradients_* fields to control when and how gradients are saved.
-
monitor:
Optional[bool] = None¶ Whether to run GAP monitoring (forward/backward hooks, per-tensor stats). Only takes effect when enabled=True. Defaults to True when enabled=True.
-
dump_gradients:
Optional[bool] = None¶ Whether to dump raw gradient tensors to disk for offline analysis. Only takes effect when enabled=True. Defaults to False when enabled=True.
- class olmo_core.train.callbacks.GarbageCollectorCallback(gc_interval=1000, enabled=True, _start_state=None)[source]¶
Bases:
Callback
Disables automatic garbage collection during training and runs gen1 collection on a set schedule instead.
Important
This callback gets added automatically in a distributed training setting if you don’t explicitly configure it. If you want to override this callback you should subclass it.
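The general pattern (turn off automatic collection, then collect on a fixed step schedule) can be sketched with the standard gc module. This is an illustration of the idea rather than the callback's code; run_steps_with_manual_gc is a hypothetical helper and the loop body stands in for a training step.

```python
import gc


def run_steps_with_manual_gc(num_steps, gc_interval=1000):
    """Hypothetical loop: automatic GC off, generation-1 collection on a schedule."""
    was_enabled = gc.isenabled()
    gc.disable()
    collected_at = []
    try:
        for step in range(1, num_steps + 1):
            # ... a training step would run here ...
            if step % gc_interval == 0:
                gc.collect(1)  # request a generation-1 collection
                collected_at.append(step)
    finally:
        # Restore the interpreter's previous setting.
        if was_enabled:
            gc.enable()
    return collected_at


collected_at = run_steps_with_manual_gc(3000, gc_interval=1000)
```

Collecting on a schedule keeps every rank pausing at the same steps instead of at unpredictable moments, which is the usual motivation for this pattern in distributed jobs.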
- class olmo_core.train.callbacks.GPUMemoryMonitorCallback(device_id=None, _num_alloc_retries=0)[source]¶
Bases:
Callback
Adds metrics for GPU memory statistics.
- class olmo_core.train.callbacks.HFConverterCallback(enabled=True, output_folder=None, dtype='bfloat16', validate=False, debug=False, tokenizer_id=None, max_sequence_length=None, device=None, moe_capacity_factor=None)[source]¶
Bases:
Callback
Converts the final saved checkpoint to HuggingFace format at the end of a training job.
This callback runs after training completes and uses olmo_core.nn.hf.convert_checkpoint_to_hf() to convert the final OLMo Core checkpoint to a HuggingFace-compatible format.
Note
This callback requires the transformers library to be installed.
Warning
In distributed training, ALL ranks must participate in this callback because gathering the full model state dict from FSDP requires collective operations. Only rank 0 performs the actual HF conversion and saving.
-
priority:
ClassVar[int] = -1¶ Priority of the callback. Determines the order in which callbacks run relative to each other. The higher the priority, the earlier a callback runs.
-
output_folder:
Optional[str] = None¶ The folder to save the HuggingFace checkpoint to. If not specified, defaults to {checkpoint_path}-hf where checkpoint_path is the final checkpoint path.
-
dtype:
Optional[DType] = 'bfloat16'¶ The dtype to save the HuggingFace model weights as. Defaults to bfloat16.
-
validate:
bool = False¶ Whether to validate the converted model against the original model. Validation loads both models and compares their outputs.
-
debug:
bool = False¶ Whether to output debug information during validation. Only has an effect if validate is True.
-
tokenizer_id:
Optional[str] = None¶ The HuggingFace tokenizer identifier to save with the model. If not specified, uses the tokenizer from the experiment config.
- class olmo_core.train.callbacks.ProfilerCallback(skip_first=0, wait=1, warmup=5, active=3, repeat=1, with_stack=True, profile_memory=False, enable_cuda_sync_events=False, enabled=True, ranks=None, _first_batch=True)[source]¶
Bases:
Callback
Enables profiling/tracing of training steps using torch.profiler. Saves the results to a subdirectory of the save folder named “profiler”.
-
enable_cuda_sync_events:
bool = False¶ Whether to enable recording of CUDA sync events. Useful for critical-path analysis with https://hta.readthedocs.io/en/latest/source/features/lightweight_critical_path_analysis.html
-
ranks:
Optional[str] = None¶ Ranks to profile. Can be:
- None: Only rank 0 is profiled
- String shortcuts:
  - "dp": Profile one rank (local rank 0) in each data parallel group
  - "tp": Profile one rank (local rank 0) in each tensor parallel group
  - "cp": Profile one rank (local rank 0) in each context parallel group
  - "pp": Profile one rank (local rank 0) in each pipeline parallel group
  - "ep": Profile one rank (local rank 0) in each expert parallel group
  - "all": Profile all ranks
Useful in conjunction with https://github.com/facebookresearch/HolisticTraceAnalysis to analyze traces from a distributed training job.
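The skip_first, wait, warmup, active, and repeat fields have the same names as the arguments of torch.profiler.schedule. Assuming the callback forwards them unchanged (not confirmed here), the pure-Python model below shows which steps would be recorded under the default values. profiler_phase is a hypothetical helper that mimics the schedule's cycle and does not call PyTorch.

```python
def profiler_phase(step, skip_first=0, wait=1, warmup=5, active=3, repeat=1):
    """Hypothetical model of a skip/wait/warmup/active profiling cycle."""
    if step < skip_first:
        return "none"
    step -= skip_first
    cycle = wait + warmup + active
    # repeat=0 means the cycle repeats indefinitely.
    if repeat > 0 and step // cycle >= repeat:
        return "none"
    position = step % cycle
    if position < wait:
        return "none"
    if position < wait + warmup:
        return "warmup"
    return "record"


phases = [profiler_phase(step) for step in range(11)]
```

Under these assumptions the defaults skip one step, warm up for five, record three, and then stay idle.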
- class olmo_core.train.callbacks.SlackNotifierCallback(name=None, notifications='end_only', enabled=True, webhook_url=None)[source]¶
Bases:
Callback
-
notifications:
SlackNotificationSetting = 'end_only'¶ The notification settings.
- class olmo_core.train.callbacks.SlackNotificationSetting(value)[source]¶
Bases:
StrEnum
Defines the notification settings for the Slack notifier callback.
- all = 'all'¶
Send all types of notifications.
- end_only = 'end_only'¶
Only send a notification when the experiment ends (successfully or with a failure).
- failure_only = 'failure_only'¶
Only send a notification when the experiment fails.
- none = 'none'¶
Don’t send any notifications.
- class olmo_core.train.callbacks.SequenceLengthSchedulerCallback(min_sequence_length=128, warmup_steps=2000, truncate=False, keep_multiple_of=128, enabled=True, _og_rank_microbatch_size=None, _last_seq_len=None)[source]¶
Bases:
Callback
A Callback for introducing a linear sequence-length warm-up schedule over the course of warmup_steps, starting from min_sequence_length and ending at the configured training sequence length (NumpyFSLDataset.sequence_length).
When truncate is False the scheduler works by splitting each instance in a batch into more, shorter instances while maintaining the same number of tokens in each batch and micro-batch. In this case the sequence length set during the warm-up will always be a multiple of min_sequence_length by a power of 2, and therefore the train sequence length must be a multiple of min_sequence_length by a power of 2.
Otherwise the scheduler simply truncates the instances in the batch to the desired sequence length, throwing out the extra tokens. The scheduler will ensure the sequence length during the warm-up is always a multiple of keep_multiple_of.
Important
This callback is only compatible with a NumpyFSLDataLoader training data_loader.
Note
The “total tokens” recorded by the trainer and SpeedMonitorCallback will still include tokens truncated by this callback for bookkeeping purposes.
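The two modes can be contrasted with a small sketch. It assumes the linear target is rounded down, either to the largest power-of-2 multiple of min_sequence_length (splitting mode) or to a multiple of keep_multiple_of (truncate mode). The rounding direction, the warmup_seq_len helper, and the 4096-token training length are assumptions made for illustration.

```python
def warmup_seq_len(
    step,
    warmup_steps=2000,
    min_sequence_length=128,
    max_sequence_length=4096,  # hypothetical training sequence length
    truncate=False,
    keep_multiple_of=128,
):
    """Hypothetical schedule: linear target, rounded down to an allowed length."""
    if step >= warmup_steps:
        return max_sequence_length
    span = max_sequence_length - min_sequence_length
    target = min_sequence_length + span * step / warmup_steps
    if truncate:
        # Truncate mode: any multiple of keep_multiple_of is allowed.
        seq_len = int(target) // keep_multiple_of * keep_multiple_of
        return max(seq_len, keep_multiple_of)
    # Splitting mode: only min_sequence_length * 2**k is allowed.
    seq_len = min_sequence_length
    while seq_len * 2 <= target:
        seq_len *= 2
    return seq_len


split_lengths = [warmup_seq_len(s) for s in (0, 1000, 1500, 2000)]
truncated_lengths = [warmup_seq_len(s, truncate=True) for s in (0, 1000, 1500, 2000)]
```

At step 1500 the linear target is 3104 tokens, so splitting mode stays at 2048 while truncate mode can use 3072.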
- class olmo_core.train.callbacks.SpeedMonitorCallback(num_flops_per_token=None, num_params=None, device_peak_flops_per_second=None, _total_steps=0, _total_tokens=0, _total_flops=0, _start_time=0.0, _first_step=True, _step_last_logged=0.0, _batch_load_start=0.0, _batch_load_time=0.0, _step_tokens=0, _step_seq_len=0, _step_flops=0, _parallel_degree=1, _bps_avg=None, _tps_avg=None, _mfu_avg=None)[source]¶
Bases:
Callback
Monitors throughput.
Important
This callback gets added automatically if you don’t explicitly configure it. If you want to override this callback you should subclass it.
- class olmo_core.train.callbacks.StabilityMonitorCallback(window_size=128, rolling_window=10000, threshold_std=6.0, enabled=True, loss_metric_name='train/CE loss', grad_norm_metric_name='optim/total grad norm', _loss_history=<factory>, _grad_norm_history=<factory>, _spike_history=<factory>, _total_spike_count=0, _total_step_count=0)[source]¶
Bases:
Callback
Monitors training stability by tracking “spikes” in loss and gradient norm.
A spike is detected when a value exceeds the running mean of the last window_size values by more than threshold_std standard deviations. This helps identify training instability.
Metrics recorded:
- spike/SpikeScore: Running spike rate over the last rolling_window steps. Only recorded once the rolling window is full.
- spike/SpikeScore (total): Cumulative spike rate (total spikes / total steps).
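The spike rule and the two scores can be sketched with the standard library. SpikeTracker is a hypothetical class; whether the real callback uses the population or sample standard deviation, and whether spike values enter the history, are assumptions of this sketch.

```python
import statistics
from collections import deque


class SpikeTracker:
    """Hypothetical tracker for one metric (e.g. loss or grad norm)."""

    def __init__(self, window_size=128, rolling_window=10000, threshold_std=6.0):
        self.values = deque(maxlen=window_size)
        self.spikes = deque(maxlen=rolling_window)
        self.threshold_std = threshold_std
        self.total_spikes = 0
        self.total_steps = 0

    def update(self, value):
        is_spike = False
        if len(self.values) >= 2:
            mean = statistics.fmean(self.values)
            std = statistics.pstdev(self.values)
            # Spike: more than threshold_std deviations above the running mean.
            is_spike = value > mean + self.threshold_std * std
        self.values.append(value)
        self.spikes.append(is_spike)
        self.total_spikes += int(is_spike)
        self.total_steps += 1
        return is_spike

    def rolling_score(self):
        # Only defined once the rolling window is full.
        if len(self.spikes) < self.spikes.maxlen:
            return None
        return sum(self.spikes) / len(self.spikes)

    def total_score(self):
        return self.total_spikes / max(self.total_steps, 1)


tracker = SpikeTracker(window_size=128, rolling_window=10, threshold_std=6.0)
for i in range(50):
    tracker.update(1.0 if i % 2 == 0 else 1.1)
spiked = tracker.update(10.0)
```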
- class olmo_core.train.callbacks.WandBCallback(enabled=True, name=None, project=None, entity=None, group=None, tags=None, notes=None, config=None, cancel_tags=<factory>, cancel_check_interval=None, _finalized=False)[source]¶
Bases:
Callback
Logs metrics to Weights & Biases from rank 0.
Important
Requires the wandb package and the environment variable WANDB_API_KEY.
Note
This callback logs metrics from every single step to W&B, regardless of the value of Trainer.metrics_collect_interval.
-
cancel_tags:
Optional[List[str]]¶ If you add any of these tags to a run on W&B, the run will cancel itself. Defaults to
["cancel", "canceled", "cancelled"].
-
cancel_check_interval:
Optional[int] = None¶ Check for cancel tags every this many steps. Defaults to
olmo_core.train.Trainer.cancel_check_interval.
- class olmo_core.train.callbacks.BeakerCallback(experiment_id=None, update_interval=None, description=None, enabled=None, config=None, result_dir='/results', _url=None, _last_update=None)[source]¶
Bases:
Callback
Adds metadata to the Beaker experiment description when running as a Beaker batch job.
-
priority:
ClassVar[int] = -1¶ Priority of the callback. Determines the order in which callbacks run relative to each other. The higher the priority, the earlier a callback runs.
- class olmo_core.train.callbacks.BatchSizeSchedulerCallback(batch_sizes=<factory>, schedule=<factory>, enabled=True)[source]¶
Bases:
Callback
A callback for setting a batch size schedule over the course of a training run. Also adjusts the base learning rate with Adam optimizers for transformer train modules by a factor of sqrt(new_batch_size / current_batch_size).
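The square-root rule stated above works out as follows. rescale_lr is a hypothetical helper and the batch sizes and learning rate are example values.

```python
import math


def rescale_lr(base_lr, current_batch_size, new_batch_size):
    """Scale the learning rate by sqrt(new_batch_size / current_batch_size)."""
    return base_lr * math.sqrt(new_batch_size / current_batch_size)


# Quadrupling the batch size doubles the learning rate.
new_lr = rescale_lr(3e-4, current_batch_size=1024, new_batch_size=4096)
```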
- class olmo_core.train.callbacks.MonkeyPatcherCallback[source]¶
Bases:
Callback
While looking into performance issues with OLMo3 training, we discovered that DeviceMesh.__getitem__() can become a bottleneck because it gets called very often by FSDP and creates a new sub-mesh object each time. So this callback patches that method to cache the sub-meshes.
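The caching idea can be shown on a toy class without touching PyTorch. ToyMesh and patch_getitem_with_cache are hypothetical, and the per-instance cache used here is one possible design, not necessarily what the callback does.

```python
import functools


class ToyMesh:
    """Hypothetical stand-in for a mesh whose __getitem__ builds a new object."""

    submeshes_created = 0

    def __init__(self, name):
        self.name = name

    def __getitem__(self, dim_name):
        ToyMesh.submeshes_created += 1
        return ToyMesh(f"{self.name}[{dim_name}]")


def patch_getitem_with_cache(cls):
    original = cls.__getitem__

    @functools.wraps(original)
    def cached_getitem(self, key):
        # Keep one cache per instance; keys must be hashable.
        cache = self.__dict__.setdefault("_submesh_cache", {})
        if key not in cache:
            cache[key] = original(self, key)
        return cache[key]

    cls.__getitem__ = cached_getitem


patch_getitem_with_cache(ToyMesh)
mesh = ToyMesh("world")
first = mesh["dp"]
second = mesh["dp"]
```

Assigning the wrapper on the class (not on an instance) matters because Python looks up special methods such as __getitem__ on the type.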
- class olmo_core.train.callbacks.MetricSaverCallback(step_metrics_fname='metrics_step{step}.json', final_metrics_fname='metrics.json', metrics_to_capture=None, save_interval=None, fixed_steps=None, enabled=True, _metrics=None, _metrics_step=0)[source]¶
Bases:
Callback
A callback that captures the latest metrics on rank 0 and saves them to a JSON file in the trainer’s save_folder.
-
step_metrics_fname:
str = 'metrics_step{step}.json'¶ The filename to save the step metrics to, with {step} as a placeholder for the step number.
-
metrics_to_capture:
Optional[List[str]] = None¶ An optional list of glob patterns to filter which metrics to capture. If
None, all metrics are captured.
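Glob filtering and the {step} placeholder can be sketched with the standard library. Using fnmatch.fnmatchcase for the patterns is an assumption about the matching rules, and filter_metrics is a hypothetical helper.

```python
from fnmatch import fnmatchcase


def filter_metrics(metrics, patterns=None):
    """Hypothetical filter: keep metrics whose name matches any glob pattern."""
    if patterns is None:
        return dict(metrics)  # no filter: capture everything
    return {
        name: value
        for name, value in metrics.items()
        if any(fnmatchcase(name, pattern) for pattern in patterns)
    }


metrics = {
    "train/CE loss": 2.5,
    "optim/total grad norm": 1.0,
    "spike/SpikeScore (total)": 0.0,
}
captured = filter_metrics(metrics, ["train/*", "optim/*"])
fname = "metrics_step{step}.json".format(step=100)
```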
- class olmo_core.train.callbacks.ModelMergeCallback(merge_step=<factory>, merge_interval=None, merge_last_n_steps=500, output_suffix='merged', enabled=False, _accumulators=<factory>, _accumulator_counts=<factory>, _merge_steps=<factory>, _completed_merges=<factory>)[source]¶
Bases:
Callback
Averages model weights over the last merge_last_n_steps before each merge_step and saves the result as a merged checkpoint.
Ephemeral checkpoints are blocked during merge windows to ensure the full window is always re-accumulated on resume.
Warning
This callback should be enabled with intention and configured with your training schedule in mind. Merge steps should be configured outside of decay phases where possible to ensure the averaged weights reflect a stable training regime.
-
priority:
ClassVar[int] = 2¶ Priority of the callback. Determines the order in which callbacks run relative to each other. The higher the priority, the earlier a callback runs.
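The averaging window can be sketched with plain floats standing in for weight tensors. WeightAverager is hypothetical, and the window boundaries (the last merge_last_n_steps steps up to and including merge_step) are an assumption of this sketch.

```python
class WeightAverager:
    """Hypothetical accumulator for one merge step; floats stand in for tensors."""

    def __init__(self, merge_step, merge_last_n_steps=500):
        self.merge_step = merge_step
        self.window_start = merge_step - merge_last_n_steps + 1
        self.sums = {}
        self.count = 0

    def in_window(self, step):
        return self.window_start <= step <= self.merge_step

    def accumulate(self, step, weights):
        if not self.in_window(step):
            return
        for name, value in weights.items():
            self.sums[name] = self.sums.get(name, 0.0) + value
        self.count += 1

    def merged(self):
        # Uniform average over every step seen inside the window.
        return {name: total / self.count for name, total in self.sums.items()}


averager = WeightAverager(merge_step=10, merge_last_n_steps=4)
for step in range(1, 11):
    averager.accumulate(step, {"w": float(step)})
merged = averager.merged()
```

Because the accumulator only holds a partial sum mid-window, resuming from a checkpoint taken inside the window would lose earlier contributions, which is consistent with blocking ephemeral checkpoints during merge windows.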
- class olmo_core.train.callbacks.ListCheckpointerCallback(save_interval=1000000000, ephemeral_save_interval=None, pre_train_checkpoint=None, save_async=None, remove='ephemeral_only', ephemeral_cooldown=None, fixed_steps=None, enabled=True, _latest_checkpoint_step=-1, _latest_checkpoint_path='', _checkpoints=<factory>, _ephemeral_checkpoints=<factory>, _checkpoints_to_remove=<factory>, save_steps=None)[source]¶
Bases:
CheckpointerCallback
Save checkpoints only at specific steps provided in a list.
Pass save_steps as a sorted list of step numbers (integers) at which to save. All other base behavior (async save, removal) is preserved.
This is useful for saving at predetermined milestones, such as:
- Period boundaries in WSD-S schedules (when LR = 0)
- Specific token budgets
- Other training milestones
Example
save_steps = [100, 500, 1000, 2000] # save at these exact steps