@dataclass
class SequenceLengthSchedulerCallback(Callback):
    """
    A :class:`Callback` for introducing a linear sequence-length warm-up schedule over the course
    of :data:`warmup_steps` starting from :data:`min_sequence_length` and ending at the configured
    training sequence length
    (:data:`NumpyFSLDataset.sequence_length <olmo_core.data.NumpyFSLDataset.sequence_length>`).

    When :data:`truncate` is ``False`` the scheduler works by splitting each instance in a batch into
    more shorter instances while maintaining the same number of tokens in each batch and micro-batch.
    In this case the sequence length set during the warm-up will always be a multiple of
    :data:`min_sequence_length` by a power of 2, and therefore the train sequence length must be a
    multiple of :data:`min_sequence_length` by a power of 2.

    Otherwise the scheduler simply truncates the instances in the batch to the desired sequence
    length, throwing out the extra tokens. The scheduler will ensure the sequence length during the
    warm-up is always a multiple of :data:`keep_multiple_of`.

    .. important::
        This callback is only compatible with a
        :class:`~olmo_core.data.data_loader.NumpyFSLDataLoader` training
        :data:`~olmo_core.train.Trainer.data_loader` and a :class:`TransformerTrainModule`.

    .. note::
        The "total tokens" recorded by the trainer and :class:`SpeedMonitorCallback` will still
        include tokens truncated by this callback for bookkeeping purposes.
    """

    min_sequence_length: int = 128
    warmup_steps: int = 2000
    truncate: bool = False
    keep_multiple_of: int = 128
    enabled: bool = True

    # Internal state: the train module's rank micro-batch size captured in ``pre_train()``,
    # used to rescale it during warm-up and to restore it afterwards.
    _og_rank_microbatch_size: Optional[int] = None
    # Internal state: the sequence length applied on the most recent warm-up step, used to
    # detect changes (for logging and CUDA cache cleanup).
    _last_seq_len: Optional[int] = None

    @staticmethod
    def _is_power_of_two_multiple(value: int, base: int) -> bool:
        """Return ``True`` iff ``value == base * 2**k`` for some integer ``k >= 0``."""
        if base <= 0 or value <= 0 or value % base != 0:
            return False
        ratio = value // base
        # Exact integer test; avoids float rounding issues of ``math.log(ratio, 2) % 1``.
        return ratio & (ratio - 1) == 0

    def pre_train(self):
        """
        Validate the trainer setup and record the original rank micro-batch size.

        :raises OLMoConfigurationError: If the data loader is not a ``NumpyFSLDataLoader``, the
            train module is not a ``TransformerTrainModule``, the train sequence length is not
            greater than :data:`min_sequence_length`, or (when ``truncate=False``) the train
            sequence length is not :data:`min_sequence_length` times a power of 2.
        """
        if not self.enabled:
            return

        if not isinstance(self.trainer.data_loader, NumpyFSLDataLoader):
            raise OLMoConfigurationError(
                "The sequence length scheduler callback requires a 'NumpyFSLDataLoader', "
                f"got '{type(self.trainer.data_loader)}' instead"
            )
        if not isinstance(self.trainer.train_module, TransformerTrainModule):
            raise OLMoConfigurationError(
                "The sequence length scheduler callback requires a 'TransformerTrainModule', "
                f"got '{type(self.trainer.train_module)}' instead"
            )

        dataset = self.trainer.data_loader.dataset
        assert isinstance(dataset, NumpyFSLDataset)

        if dataset.sequence_length <= self.min_sequence_length:
            raise OLMoConfigurationError(
                "train sequence length must be greater than 'min_sequence_length'"
            )

        # In split mode (truncate=False) the warm-up lengths are min_sequence_length times a
        # power of 2, so the final train sequence length must lie on that same ladder.
        if not self.truncate and not self._is_power_of_two_multiple(
            dataset.sequence_length, self.min_sequence_length
        ):
            raise OLMoConfigurationError(
                "train sequence length must be a multiple of 'min_sequence_length' by a power of 2 "
                "when 'truncate=False'."
            )

        self._og_rank_microbatch_size = self.trainer.train_module.rank_microbatch_size

    def pre_step(self, batch: Dict[str, Any]):
        """
        Shorten the instances of ``batch`` in place according to the warm-up schedule.

        Does nothing when the callback is disabled or once the current step is past
        :data:`warmup_steps`. With ``truncate=True`` instances are truncated via
        ``truncate_batch``; otherwise they are split via ``melt_batch`` and the rank
        micro-batch size is scaled up by ``sequence_length // new_seq_len``.

        :param batch: The batch dictionary; its entries are overwritten in place.
        """
        if not self.enabled or self.step > self.warmup_steps:
            return

        assert isinstance(self.trainer.data_loader, NumpyFSLDataLoader)
        dataset = self.trainer.data_loader.dataset
        assert isinstance(dataset, NumpyFSLDataset)
        assert isinstance(self.trainer.train_module, TransformerTrainModule)

        new_seq_len: int
        if self.truncate:
            new_seq_len = _get_truncated_sequence_length(
                self.min_sequence_length,
                dataset.sequence_length,
                self.step,
                self.warmup_steps,
                self.keep_multiple_of,
            )
            for key, value in truncate_batch(batch, new_seq_len).items():
                batch[key] = value
        else:
            new_seq_len = _get_split_sequence_length(
                self.min_sequence_length,
                dataset.sequence_length,
                self.step,
                self.warmup_steps,
            )
            for key, value in melt_batch(batch, new_seq_len).items():
                batch[key] = value

            # Increase micro-batch size proportionally to maintain the same number of tokens
            # in each micro-batch.
            assert self._og_rank_microbatch_size is not None
            new_rank_microbatch_size = self._og_rank_microbatch_size * (
                dataset.sequence_length // new_seq_len
            )
            self.trainer.train_module.rank_microbatch_size = new_rank_microbatch_size

        if new_seq_len != self._last_seq_len:
            log.info(f"Changing sequence length to {new_seq_len} per warm-up schedule")
            self._last_seq_len = new_seq_len
            # Empty CUDA cache since shapes have now changed.
            gc_cuda()

    def post_train_batch(self):
        """
        Restore the train module's rank micro-batch size to the value recorded in ``pre_train()``.
        """
        # NOTE(review): the ``+ 1`` margin presumably makes the restore also run on the step right
        # after warm-up ends — confirm against how ``self.step`` advances between hooks.
        if not self.enabled or self.step > self.warmup_steps + 1:
            return
        assert isinstance(self.trainer.train_module, TransformerTrainModule)
        assert self._og_rank_microbatch_size is not None
        self.trainer.train_module.rank_microbatch_size = self._og_rank_microbatch_size