Source code for olmo_core.data.composable.data_loader

import dataclasses
import functools as ft
import logging
import math
from dataclasses import dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist

import olmo_core.distributed.utils as dist_utils
from olmo_core.aliases import PathOrStr
from olmo_core.config import Config, StrEnum
from olmo_core.distributed.parallel import get_dp_process_group
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.utils import get_default_device, roundrobin, threaded_generator

from ..collator import DataCollator
from ..data_loader import DataLoaderConfig, TextDataLoaderBase
from ..tokenizer import TokenizerConfig
from ..utils import (
    find_periodic_sequences,
    get_document_lengths,
    get_rng,
    iter_batched,
    memmap_to_write,
)
from .instance_source import InstanceSource
from .utils import (
    SEED_NOT_SET,
    as_tensor,
    build_global_indices,
    format_fname_from_fields,
    resolve_seed,
)

log = logging.getLogger(__name__)


@dataclass
class InstanceFilterConfig(Config):
    """
    Settings for flagging instances that contain long runs of a repeated token pattern.

    The data loader passes ``repetition_max_period`` / ``repetition_min_period`` to
    ``find_periodic_sequences`` as ``max_period`` / ``min_period``. An instance is flagged
    (``instance_mask=False``) as soon as one reported match repeats at least
    ``repetition_max_count`` times.
    """

    # Forwarded as ``max_period`` to ``find_periodic_sequences``.
    repetition_max_period: int = 13
    # Forwarded as ``min_period`` to ``find_periodic_sequences``.
    repetition_min_period: int = 1
    # A match with ``times >= repetition_max_count`` marks the instance as filtered.
    repetition_max_count: int = 32
class ShuffleStrategy(StrEnum):
    """
    How instances drawn from several sources are ordered when shuffling is enabled.
    """

    inter_source = "inter_source"
    """
    Pool every source together and shuffle the combined pool as a single source.
    """

    intra_source = "intra_source"
    """
    Shuffle each source on its own, then place the sources back to back in their given order.
    Handy for building a data curriculum.
    """

    interleaved_source = "interleaved_source"
    """
    Shuffle each source on its own, then alternate between the sources one instance at a time.
    Every source is truncated to the length of the shortest one so each contributes equally.
    """
@DataLoaderConfig.register("composable")
@dataclass
class ComposableDataLoaderConfig(DataLoaderConfig["ComposableDataLoader"]):
    """
    A configuration class for building :class:`ComposableDataLoader` data loaders.

    The fields mirror the keyword arguments of :class:`ComposableDataLoader`. ``tokenizer``,
    ``global_batch_size`` and ``work_dir`` may be left unset here and passed to :meth:`build`
    instead.
    """

    tokenizer: Optional[TokenizerConfig] = None
    global_batch_size: Optional[int] = None
    # Resolved eagerly so every build from this config uses the same seed.
    seed: int = dataclasses.field(default_factory=lambda: resolve_seed(SEED_NOT_SET))
    work_dir: Optional[str] = None
    shuffle: bool = True
    shuffle_strategy: Optional[ShuffleStrategy] = None
    sources_per_epoch: int = -1
    num_threads: Optional[int] = None
    num_workers: int = 0
    prefetch_factor: Optional[int] = None
    target_device_type: Optional[str] = None
    generate_doc_lengths: bool = False
    instance_filter_config: Optional[InstanceFilterConfig] = None
    display_source_visualization: bool = True

    def __post_init__(self, *args):
        del args
        if self.sources_per_epoch == 0 or self.sources_per_epoch < -1:
            raise OLMoConfigurationError(
                "'sources_per_epoch' must be -1 (for all sources) or a positive integer."
            )
        if not self.shuffle and self.shuffle_strategy is not None:
            raise OLMoConfigurationError("'shuffle_strategy' cannot be set if 'shuffle' is False.")

    def build(
        self,
        *sources: InstanceSource,
        collator: Optional[DataCollator] = None,
        work_dir: Optional[PathOrStr] = None,
        mesh: Optional[dist.DeviceMesh] = None,
        dp_process_group: Optional[dist.ProcessGroup] = None,
        tokenizer: Optional[TokenizerConfig] = None,
        global_batch_size: Optional[int] = None,
    ) -> "ComposableDataLoader":
        """
        Construct the :class:`ComposableDataLoader`.

        :param sources: The instance sources.
        :param collator: An optional data collator. If not provided, a default
            :class:`DataCollator` is created from the tokenizer's pad token ID and padded vocab size.
        :param work_dir: A working directory for caching. Falls back to ``self.work_dir``.
        :param mesh: An optional ``DeviceMesh`` that defines the data parallel dimensions.
            Ideally you should create this mesh using :func:`~olmo_core.distributed.parallel.build_world_mesh()`.
            Alternatively you can pass the ``dp_process_group`` instead.
        :param dp_process_group: The data parallel process group. Takes precedence over ``mesh``.
        :param tokenizer: Overrides ``self.tokenizer`` when given.
        :param global_batch_size: Overrides ``self.global_batch_size`` (in tokens) when given.

        :raises OLMoConfigurationError: If no sources are given, or if ``work_dir``,
            ``tokenizer`` or ``global_batch_size`` is missing both here and on the config.
        """
        if not sources:
            raise OLMoConfigurationError("At least one 'source' must be provided.")

        if dp_process_group is None and mesh is not None:
            dp_process_group = get_dp_process_group(mesh)

        work_dir = work_dir or self.work_dir
        if work_dir is None:
            raise OLMoConfigurationError("'work_dir' must be specified.")

        tokenizer = tokenizer if tokenizer is not None else self.tokenizer
        if tokenizer is None:
            raise OLMoConfigurationError("'tokenizer' must be specified.")

        global_batch_size = (
            global_batch_size if global_batch_size is not None else self.global_batch_size
        )
        if global_batch_size is None:
            raise OLMoConfigurationError("'global_batch_size' must be specified.")

        # Explicit ``None`` check rather than truthiness, so a user-supplied collator is always used.
        if collator is None:
            collator = DataCollator(
                pad_token_id=tokenizer.pad_token_id, vocab_size=tokenizer.padded_vocab_size()
            )

        return ComposableDataLoader(
            *sources,
            collator=collator,
            tokenizer=tokenizer,
            work_dir=work_dir,
            global_batch_size=global_batch_size,
            dp_world_size=dist_utils.get_world_size(dp_process_group),
            dp_rank=dist_utils.get_rank(dp_process_group),
            fs_local_rank=dist_utils.get_fs_local_rank(),
            seed=self.seed,
            shuffle=self.shuffle,
            shuffle_strategy=self.shuffle_strategy,
            sources_per_epoch=self.sources_per_epoch,
            num_threads=self.num_threads,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
            target_device_type=self.target_device_type or get_default_device().type,
            generate_doc_lengths=self.generate_doc_lengths,
            instance_filter_config=self.instance_filter_config,
            display_source_visualization=self.display_source_visualization,
        )
class ComposableDataLoader(TextDataLoaderBase):
    """
    A data loader for composable instance sources.

    :param sources: One or more instance sources to draw data from. All sources must have the
        same ``sequence_length`` and ``max_sequence_length``.
    :param collator: The data collator to use to form batches.
    :param tokenizer: The config of the tokenizer used to create the underlying data.
    :param work_dir: A common local working directory that can be used for caching.
    :param global_batch_size: The total batch size (in tokens) across all data parallel ranks.
    :param dp_world_size: The number of data parallel ranks.
    :param dp_rank: The data parallel rank of the current process.
    :param fs_local_rank: The local rank of the current process with respect to filesystem
        access of the working directory.
    :param seed: The random seed to use when shuffling data.
    :param shuffle: Whether to shuffle data at the start of each epoch.
    :param shuffle_strategy: How to shuffle the data. Defaults to :data:`ShuffleStrategy.inter_source`.
    :param sources_per_epoch: The number of sources to use per epoch. If -1, all sources are used.
    :param num_threads: The number of threads to use for loading data within each worker process.
    :param num_workers: The number of worker processes to use for loading data.
    :param prefetch_factor: The number of batches to prefetch from each worker process.
    :param target_device_type: The type of device that batches will be sent to, typically either
        "cpu" or "cuda".
    :param generate_doc_lengths: Whether to generate document lengths for each instance needed
        for intra-document masking.
    :param instance_filter_config: Optional configuration for filtering instances based on long
        sequences of repeated ngrams.
    :param display_source_visualization: Whether to display a visualization of each source to
        stdout from rank 0.

    :raises OLMoConfigurationError: If the sources or batch/sequence settings are inconsistent.
    """

    Config = ComposableDataLoaderConfig

    def __init__(
        self,
        *sources: InstanceSource,
        collator: DataCollator,
        tokenizer: TokenizerConfig,
        work_dir: PathOrStr,
        global_batch_size: int,
        dp_world_size: int = 1,
        dp_rank: int = 0,
        fs_local_rank: Optional[int] = None,
        seed: int = SEED_NOT_SET,
        shuffle: bool = True,
        shuffle_strategy: Optional[ShuffleStrategy] = None,
        sources_per_epoch: int = -1,
        num_threads: Optional[int] = None,
        num_workers: int = 0,
        prefetch_factor: Optional[int] = None,
        target_device_type: str = "cpu",
        generate_doc_lengths: bool = False,
        instance_filter_config: Optional[InstanceFilterConfig] = None,
        display_source_visualization: bool = True,
    ):
        if not sources:
            raise OLMoConfigurationError("'sources' must contain at least one InstanceSource.")
        if sources_per_epoch == 0 or sources_per_epoch < -1:
            raise OLMoConfigurationError(
                "'sources_per_epoch' must be -1 (for all sources) or a positive integer."
            )
        if sources_per_epoch > 0 and len(sources) % sources_per_epoch != 0:
            raise OLMoConfigurationError(
                "'sources_per_epoch' must evenly divide into the number of sources."
            )
        if tokenizer.pad_token_id is not None and tokenizer.pad_token_id != collator.pad_token_id:
            raise OLMoConfigurationError(
                "'tokenizer.pad_token_id' must match 'collator.pad_token_id'."
            )
        if not shuffle and shuffle_strategy is not None:
            raise OLMoConfigurationError("'shuffle_strategy' cannot be set if 'shuffle' is False.")
        if shuffle and shuffle_strategy is None:
            shuffle_strategy = ShuffleStrategy.inter_source

        super().__init__(
            collator=collator,
            work_dir=work_dir,
            global_batch_size=global_batch_size,
            dp_world_size=dp_world_size,
            dp_rank=dp_rank,
            fs_local_rank=fs_local_rank,
        )
        self.tokenizer = tokenizer
        self.seed = resolve_seed(seed)
        self.shuffle = shuffle
        self.shuffle_strategy = shuffle_strategy
        self.sources_per_epoch = sources_per_epoch if sources_per_epoch > 0 else len(sources)
        self.sources = tuple(sources)

        self.sequence_length = self.sources[0].sequence_length
        self.max_sequence_length = self.sources[0].max_sequence_length
        for i, source in enumerate(self.sources):
            if source.sequence_length != self.sequence_length:
                raise OLMoConfigurationError("All sources must have the same 'sequence_length'.")
            if source.max_sequence_length != self.max_sequence_length:
                raise OLMoConfigurationError(
                    "All sources must have the same 'max_sequence_length'."
                )
            if source.sequence_length != source.max_sequence_length:
                # NOTE: To guarantee the same data order with `self.max_sequence_length` fixed but
                # `self.sequence_length` changing, we need
                # `len(source) // (source.max_sequence_length // source.sequence_length)` to remain
                # constant. For example, if `sequence_length` is half of `max_sequence_length`, then
                # `len(source)` should double. This check wouldn't catch all possible violations,
                # but should catch some.
                if len(source) % (source.max_sequence_length // source.sequence_length) != 0:
                    raise OLMoConfigurationError(
                        "Each source must have a number of instances that is a multiple of "
                        "'max_sequence_length // sequence_length' when 'sequence_length' != "
                        "'max_sequence_length'. "
                        f"Source {i} does not meet this condition: {source}"
                    )

        if self.max_sequence_length % self.sequence_length != 0:
            raise OLMoConfigurationError(
                "'max_sequence_length' must be a multiple of 'sequence_length'."
            )
        if self.global_batch_size % self.sequence_length != 0:
            raise OLMoConfigurationError(
                "'global_batch_size' must be a multiple of 'sequence_length'."
            )

        self.num_threads = num_threads
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.target_device_type = target_device_type
        self.generate_doc_lengths = generate_doc_lengths
        self.instance_filter_config = instance_filter_config
        self._global_indices: Optional[np.ndarray] = None

        if display_source_visualization and dist_utils.get_rank() == 0:
            print()
            for source in sources:
                source.visualize()
                print()

    @property
    def sources_for_this_epoch(self) -> Tuple[InstanceSource, ...]:
        """The sources used in the current epoch (epoch 1 if no epoch has been set yet)."""
        return self.sources_for_epoch(self._epoch or 1)

    def sources_for_epoch(self, epoch: int) -> Tuple[InstanceSource, ...]:
        """
        Return the group of ``sources_per_epoch`` consecutive sources used in ``epoch`` (1-based).

        The groups are cycled, so epoch ``e`` uses group ``(e - 1) % num_groups``.
        """
        assert self.sources_per_epoch > 0
        assert len(self.sources) % self.sources_per_epoch == 0
        num_groups = len(self.sources) // self.sources_per_epoch
        start_offset = ((epoch - 1) % num_groups) * self.sources_per_epoch
        return self.sources[start_offset : start_offset + self.sources_per_epoch]

    @ft.cached_property
    def source_fingerprints(self) -> Tuple[str, ...]:
        """Fingerprints of all sources, in order. Used to validate checkpoint compatibility."""
        return tuple(source.fingerprint for source in self.sources)

    @property
    def total_instances(self) -> int:
        """Number of instances in the current epoch."""
        return self.instances_in_epoch(self._epoch or 1)

    def instances_in_epoch(self, epoch: int) -> int:
        """Number of instances available in ``epoch`` given the shuffle strategy."""
        sources_in_epoch = self.sources_for_epoch(epoch)
        if self.shuffle and self.shuffle_strategy == ShuffleStrategy.interleaved_source:
            # When interleaving sources, we need to make sure that each source contributes
            # equally to the total number of instances, so we take the minimum length
            # across all sources and multiply by the number of sources.
            min_length = min(len(source) for source in sources_in_epoch)
            return min_length * len(sources_in_epoch)
        else:
            return sum(len(source) for source in sources_in_epoch)

    @property
    def total_tokens(self) -> int:
        """Number of tokens in the current epoch."""
        return self.total_instances * self.sequence_length

    @property
    def total_batches(self) -> Optional[int]:
        """Number of full global batches in the current epoch."""
        return self.batches_in_epoch(self._epoch or 1)

    def batches_in_epoch(self, epoch: int) -> Optional[int]:
        """Number of full global batches in ``epoch``; a trailing partial batch is dropped."""
        return self.instances_in_epoch(epoch) // (self.global_batch_size // self.sequence_length)

    @property
    def worker_info(self):
        """The ``torch`` data loader worker info, or ``None`` outside of a worker process."""
        return torch.utils.data.get_worker_info()

    def state_dict(self) -> Dict[str, Any]:
        """Return the state needed to resume data loading from the current position."""
        return dict(
            source_fingerprints=self.source_fingerprints,
            batches_processed=self.batches_processed,
            tokens_processed=self.tokens_processed,
            global_batch_size=self.global_batch_size,
            sequence_length=self.sequence_length,
            max_sequence_length=self.max_sequence_length,
            shuffle=self.shuffle,
            shuffle_strategy=self.shuffle_strategy,
            sources_per_epoch=self.sources_per_epoch,
            seed=self.seed,
            epoch=self._epoch,
        )

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """
        Restore state produced by :meth:`state_dict`.

        Changes to the batch size or sequence length are allowed and accounted for via
        ``tokens_processed``. New sources may be appended as long as the original sources are
        unchanged and haven't been iterated over yet.

        :raises RuntimeError: If the saved state is incompatible with this data loader.
        """
        if state_dict["sources_per_epoch"] != self.sources_per_epoch:
            raise RuntimeError(
                "Restoring data loader state with a different 'sources_per_epoch' is not supported!"
            )

        # Normalize to a tuple: a list (e.g. after a serialization round trip) would never
        # compare equal to the tuple in `self.source_fingerprints`.
        og_fingerprints = tuple(state_dict["source_fingerprints"])
        if og_fingerprints != self.source_fingerprints:
            # We're allowed to append more sources provided 'sources_per_epoch' hasn't changed
            # (checked above), the existing sources haven't changed, and we haven't already
            # iterated over the original set of sources.
            num_og_sources = len(og_fingerprints)
            if (
                num_og_sources < len(self.source_fingerprints)
                and og_fingerprints == self.source_fingerprints[:num_og_sources]
            ):
                max_epochs = num_og_sources // self.sources_per_epoch
                if state_dict["epoch"] is not None and state_dict["epoch"] > max_epochs:
                    raise RuntimeError(
                        "Restoring data loader state after appending new sources can only be done "
                        "when the original sources haven't been iterated over yet!"
                    )
            else:
                raise RuntimeError(
                    "Restoring data loader state from different dataset source is not supported (fingerprints don't match)!"
                )

        if state_dict["max_sequence_length"] != self.max_sequence_length:
            raise RuntimeError(
                "Restoring data loading state with a different 'max_sequence_length' is not supported!"
            )
        if state_dict["shuffle_strategy"] != self.shuffle_strategy:
            raise RuntimeError(
                "Restoring data loading state with a different 'shuffle_strategy' is not supported!"
            )
        if state_dict["shuffle"] != self.shuffle:
            raise RuntimeError(
                "Restoring data loading state with a different shuffle setting is not supported!"
            )

        # Account for change in batch size / sequence length.
        self.tokens_processed = state_dict["tokens_processed"]
        self.batches_processed = self.tokens_processed // self.global_batch_size

        if state_dict["seed"] != self.seed and self.shuffle:
            log.warning(
                "Restoring data loading state with a different data seed, "
                "will use data seed from state dict for data order consistency."
            )
            self.seed = state_dict["seed"]

        self._epoch = state_dict["epoch"] or self._epoch

        log.info(
            f"Data loader will resume from batch {self.batches_processed:,d}/{self.total_batches:,d} "
            f"based on batch size of {self.global_batch_size:,d} tokens"
        )

    def reshuffle(self, epoch: Optional[int] = None, **kwargs):
        """
        Start a new epoch (or ``epoch``, if given) and build/save its global data order.

        :raises ValueError: If ``epoch`` is less than 1.
        """
        del kwargs
        if epoch is None:
            epoch = 1 if self._epoch is None else self._epoch + 1
        if epoch <= 0:
            raise ValueError(f"'epoch' must be at least 1, got {epoch}")
        self._epoch = epoch
        self.build_and_save_global_indices()

    def _iter_batches(self) -> Iterable[Dict[str, Any]]:
        # If we're already at the end of epoch we can skip creating the iterator.
        if self.total_batches is not None and self.batches_processed >= self.total_batches:
            yield from ()
            return

        def _build_batch_iterator():
            return iter(
                torch.utils.data.DataLoader(
                    _IterableDataLoaderWrapper(self),
                    batch_size=None,
                    num_workers=self.num_workers,
                    pin_memory=self.target_device_type == "cuda" and self.num_workers > 0,
                    prefetch_factor=self.prefetch_factor,
                    persistent_workers=False,
                    timeout=0,
                ),
            )

        current_global_batch_size = self.global_batch_size
        batch_iterator = _build_batch_iterator()
        while (batch := next(batch_iterator, None)) is not None:
            yield batch

            # If batch size has changed, re-initialize the workers.
            # NOTE: base class handles the logic of adjusting `self.batches_processed` when
            # `self.global_batch_size` is changed through the property setter.
            if current_global_batch_size != self.global_batch_size:
                if self.num_workers > 0:
                    log.info("Batch size has changed, reinitializing data loading workers...")
                current_global_batch_size = self.global_batch_size
                batch_iterator = _build_batch_iterator()

    def get_mock_batch(self) -> Dict[str, Any]:
        """Collate a rank-local batch of randomly chosen instances from the current epoch."""
        rng = get_rng(self.seed + self.dp_rank)
        num_instances = self.rank_batch_size // self.sequence_length
        indices = rng.integers(0, self.total_instances, num_instances)
        instances = [self.get_instance(idx) for idx in indices]
        return self.collator(instances)

    def get_instance(self, idx: int) -> Dict[str, Any]:
        """
        Fetch the instance at position ``idx`` of the concatenation of this epoch's sources.

        The returned dict always has ``index`` and ``input_ids``. It also has ``metadata`` when
        the source has a label, ``label_mask`` when the source provides one, ``doc_lens`` when
        ``generate_doc_lengths`` is set, and ``instance_mask`` when an instance filter is configured.

        :raises IndexError: If ``idx`` falls outside every source.
        """
        if idx < 0:
            idx = self.total_instances + idx

        source_start_offset = 0
        for source in self.sources_for_this_epoch:
            source_end_offset = source_start_offset + len(source)
            if source_start_offset <= idx < source_end_offset:
                out: Dict[str, Any] = {"index": idx}
                if source.label is not None:
                    out["metadata"] = {"source": source.label}

                instance = source[idx - source_start_offset]
                input_ids = as_tensor(instance["input_ids"])
                out["input_ids"] = input_ids
                if (label_mask := instance.get("label_mask")) is not None:
                    out["label_mask"] = as_tensor(label_mask)

                if self.generate_doc_lengths:
                    out["doc_lens"] = get_document_lengths(
                        input_ids,
                        self.tokenizer.eos_token_id,
                        bos_token_id=self.tokenizer.bos_token_id,
                    )

                if self.instance_filter_config is not None:
                    instance_mask = True
                    for m in find_periodic_sequences(
                        input_ids.numpy(),
                        max_period=self.instance_filter_config.repetition_max_period,
                        min_period=self.instance_filter_config.repetition_min_period,
                    ):
                        if m.times >= self.instance_filter_config.repetition_max_count:
                            instance_mask = False
                            break
                    out["instance_mask"] = instance_mask

                return out
            source_start_offset = source_end_offset

        raise IndexError(f"Index {idx} out of range for {self.total_instances} instances")

    @property
    def global_indices_file(self) -> Path:
        """Path of the cached global data order for the current epoch inside ``work_dir``."""
        global_indices_fname = format_fname_from_fields(
            "global_indices",
            seed=self.seed if self.shuffle else None,
            epoch=self.epoch if self.shuffle else None,
            shuffle=self.shuffle_strategy if self.shuffle else None,
            size=self.total_instances,
            seq_len=self.sequence_length,
            max_seq_len=self.max_sequence_length,
            v=1,  # tick if logic changes
        )
        if self.shuffle and self.shuffle_strategy in (
            ShuffleStrategy.intra_source,
            ShuffleStrategy.interleaved_source,
        ):
            # Per-source strategies lay out indices by each source's length, so two source sets
            # with the same total size (e.g. lengths [10, 20] vs [20, 10]) would otherwise share
            # a cache file in a common work dir and load the wrong order.
            import hashlib

            source_sizes = ",".join(str(len(source)) for source in self.sources_for_this_epoch)
            sizes_digest = hashlib.sha256(source_sizes.encode("utf-8")).hexdigest()[:16]
            global_indices_fname = f"{global_indices_fname}_srcsizes-{sizes_digest}"
        return self.work_dir / f"{global_indices_fname}.npy"

    def get_global_indices(self) -> np.ndarray:
        """
        Return the global data order for the current epoch as a read-only ``uint32`` memmap.

        :raises RuntimeError: If the indices file hasn't been built (see :meth:`reshuffle`).
        """
        if self._global_indices is not None:
            return self._global_indices
        if not self.global_indices_file.is_file():
            raise RuntimeError("Missing global indices file, did you forget to call 'reshuffle()'?")
        return np.memmap(self.global_indices_file, mode="r", dtype=np.uint32)  # type: ignore

    def get_local_indices(self) -> np.ndarray:
        """
        Return, as a flat array, the global instance indices this DP rank (and data loader
        worker, if any) should load, skipping batches that have already been processed.
        """
        indices = self.get_global_indices()

        # Remove tail of data to make it evenly divisible.
        instances_per_batch = self.global_batch_size // self.sequence_length
        total_size = instances_per_batch * (self.total_instances // instances_per_batch)
        indices = indices[:total_size]

        # Slice up by batch.
        # shape: (global num batches, global num instances per batch)
        indices = indices.reshape(-1, instances_per_batch)

        # Offset by the number of batches already processed.
        if self.batches_processed > 0:
            indices = indices[self.batches_processed :]

        # Slice batches by data loader worker rank to avoid duplicates.
        if (worker_info := self.worker_info) is not None:
            # Note that each data loading worker gathers a whole batch at a time, and the workers
            # are called round-robin by rank. So to slice these up in a way that preserves order,
            # regardless of the number of workers, we give worker 0 the first batch, worker 1 the
            # second batch, etc.
            indices = indices[worker_info.id :: worker_info.num_workers]

        # Finally slice batches into micro batches for the local DP rank.
        indices = indices[:, self.dp_rank :: self.dp_world_size].reshape((-1,))
        return indices

    def build_and_save_global_indices(self):
        """
        Build the global data order for the current epoch and save it to
        :attr:`global_indices_file` (filesystem-local rank 0 only), then synchronize all ranks.
        An existing file is reused.
        """
        self._global_indices = None
        if self.fs_local_rank == 0:
            if self.global_indices_file.is_file():
                log.info(
                    f"Using existing global indices file for seed {self.seed} and epoch {self.epoch} "
                    f"at:\n'{self.global_indices_file}'"
                )
            else:
                log.info(
                    f"Saving global data order indices for seed {self.seed} and epoch {self.epoch} "
                    f"to:\n'{self.global_indices_file}'..."
                )
                # NOTE: `_build_global_indices` validates that every index fits in uint32.
                global_indices = self._build_global_indices()
                with memmap_to_write(
                    self.global_indices_file, shape=global_indices.shape, dtype=np.uint32
                ) as global_indices_mmap:
                    global_indices_mmap[:] = global_indices
                log.info(f"Global data order indices saved to:\n'{self.global_indices_file}'")
        dist_utils.barrier()

    def _build_global_indices(self) -> np.ndarray:
        """
        Compute the global data order for the current epoch according to the shuffle strategy.

        :raises RuntimeError: If the sources have too many instances to index with ``uint32``.
        :raises NotImplementedError: For an unhandled shuffle strategy.
        """
        dtype = np.uint32
        sources = self.sources_for_this_epoch

        # Every strategy produces indices in [0, sum(len(source))). With interleaving, the
        # largest value can far exceed the number of indices, so check the addressable range up
        # front rather than the output length, before any uint32 arithmetic can wrap around.
        num_addressable = sum(len(source) for source in sources)
        if num_addressable >= np.iinfo(dtype).max:
            raise RuntimeError(
                f"Sources contain too many instances ({num_addressable:,d}) to index with uint32"
            )

        if not self.shuffle or self.shuffle_strategy == ShuffleStrategy.inter_source:
            return build_global_indices(
                self.total_instances,
                sequence_length=self.sequence_length,
                max_sequence_length=self.max_sequence_length,
                seed=(self.seed + self.epoch) if self.shuffle else None,
                dtype=dtype,
            )
        elif self.shuffle_strategy == ShuffleStrategy.intra_source:
            # Shuffle each source independently, then lay them out back to back.
            indices_per_source = []
            offset = 0
            for i, source in enumerate(sources):
                indices = build_global_indices(
                    len(source),
                    sequence_length=self.sequence_length,
                    max_sequence_length=self.max_sequence_length,
                    seed=self.seed + self.epoch + i,
                    dtype=dtype,
                )
                indices_per_source.append(indices + offset)
                offset += len(source)
            return np.concatenate(indices_per_source)
        elif self.shuffle_strategy == ShuffleStrategy.interleaved_source:
            # Shuffle the first `source_size` positions of each source, then interleave them
            # round-robin. Offsets still use the full source lengths to match `get_instance`.
            source_size = min(len(source) for source in sources)
            num_sources = len(sources)
            interleaved_indices = np.empty((num_sources * source_size,), dtype=dtype)
            offset = 0
            for i, source in enumerate(sources):
                indices = build_global_indices(
                    source_size,
                    sequence_length=self.sequence_length,
                    max_sequence_length=self.max_sequence_length,
                    seed=self.seed + self.epoch + i,
                    dtype=dtype,
                )
                interleaved_indices[i::num_sources] = indices + offset
                offset += len(source)
            return interleaved_indices
        else:
            raise NotImplementedError(f"Unknown shuffle strategy: {self.shuffle_strategy}")
class _IterableDataLoaderWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
    """
    Exposes a :class:`ComposableDataLoader` as a ``torch`` ``IterableDataset`` of collated,
    rank-local batches so it can be driven by ``torch.utils.data.DataLoader``.
    """

    def __init__(self, data_loader: ComposableDataLoader):
        self.data_loader = data_loader

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """
        Iterate over the local rank+worker instances, optionally fetching them with a pool of
        threads, and yield them collated in groups of ``rank_batch_size`` tokens.
        """
        loader = self.data_loader

        num_threads = loader.num_threads
        if num_threads is None and loader.worker_info is None:
            # Nothing was requested and we aren't inside a worker process, so fall back to a
            # reasonable default thread count.
            num_threads = 4

        instance_iterator: Iterator[Dict[str, Any]]
        if not num_threads:
            # Single-threaded: fetch instances in order on the calling thread.
            instance_iterator = (
                loader.get_instance(int(idx)) for idx in loader.get_local_indices()
            )
        else:
            # To stay ahead of training, the combined queue capacity across all threads should
            # exceed the number of instances in one local batch.
            instances_per_rank_batch = loader.rank_batch_size // loader.sequence_length
            queue_size = math.ceil(instances_per_rank_batch * 2 / num_threads)

            per_thread_generators = []
            for thread_id in range(num_threads):
                # Request the indices separately for every thread: if `get_local_indices` ever
                # hands back a one-shot iterator, sharing it would exhaust it early and scramble
                # the order. Thread `k` handles every `num_threads`-th index starting at `k`.
                thread_indices = loader.get_local_indices()
                thread_instances = (
                    loader.get_instance(int(idx))
                    for idx in islice(thread_indices, thread_id, None, num_threads)
                )
                per_thread_generators.append(
                    threaded_generator(
                        thread_instances, maxsize=queue_size, thread_name=f"data thread {thread_id}"
                    )
                )
            # Pull from the threads in turn to restore the original index order.
            instance_iterator = roundrobin(*per_thread_generators)

        rank_batches = iter_batched(instance_iterator, loader.rank_batch_size)
        return (loader.collator(batch) for batch in rank_batches)