Source code for olmo_core.data.composable.sampling_instance_source

import dataclasses
import functools as ft
import hashlib
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np

import olmo_core.distributed.utils as dist_utils
from olmo_core.aliases import PathOrStr
from olmo_core.exceptions import OLMoConfigurationError

from ..utils import load_array_slice, write_array_to_disk
from .instance_source import Instance, InstanceSource, InstanceSourceConfig
from .utils import SEED_NOT_SET, build_global_indices, resolve_seed


@dataclass
class SamplingInstanceSourceConfig(InstanceSourceConfig):
    """Config for :class:`SamplingInstanceSource`.

    Exactly one of ``max_tokens``, ``max_instances``, or ``factor`` must be set, and it must be
    positive. ``factor`` is a multiplier on the combined ``num_tokens`` of the built child
    sources; it is converted into ``max_tokens`` at build time. Values above 1 upsample, since
    :class:`SamplingInstanceSource` repeats a source when asked for more instances than it has.
    """

    sources: List[InstanceSourceConfig]
    max_tokens: Optional[int] = None
    max_instances: Optional[int] = None
    factor: Optional[float] = None
    seed: Optional[int] = dataclasses.field(default_factory=lambda: resolve_seed(SEED_NOT_SET))
    label: Optional[str] = None

    def __post_init__(self):
        num_limits_set = sum(
            value is not None for value in (self.max_tokens, self.max_instances, self.factor)
        )
        if num_limits_set != 1:
            raise OLMoConfigurationError(
                "Either 'max_tokens', 'max_instances', or 'factor' must be set, but not more than one."
            )
        # Reject non-positive limits up front instead of letting them reach the source's
        # constructor. `not value > 0` also rejects NaN for `factor`.
        for name in ("max_tokens", "max_instances", "factor"):
            value = getattr(self, name)
            if value is not None and not value > 0:
                raise OLMoConfigurationError(f"'{name}' must be positive, got {value}.")

    def build(self, work_dir: PathOrStr) -> "SamplingInstanceSource":
        """
        Build the child sources and wrap them in a :class:`SamplingInstanceSource`.

        :param work_dir: Working directory passed to every child source and to the sampler.

        :raises OLMoConfigurationError: If ``factor`` is set but, after truncation to an int,
            yields zero tokens to sample.
        """
        sources = [source.build(work_dir) for source in self.sources]
        max_tokens = self.max_tokens
        max_instances = self.max_instances
        if max_tokens is None and max_instances is None:
            # Guaranteed by `__post_init__`; kept for type narrowing.
            assert self.factor is not None
            total_tokens = sum(source.num_tokens for source in sources)
            max_tokens = int(self.factor * total_tokens)
            if max_tokens <= 0:
                raise OLMoConfigurationError(
                    f"'factor'={self.factor} applied to {total_tokens} source tokens "
                    "leaves no tokens to sample."
                )
        return SamplingInstanceSource(
            *sources,
            max_tokens=max_tokens,
            max_instances=max_instances,
            work_dir=work_dir,
            seed=self.seed,
            label=self.label,
        )
class SamplingInstanceSource(InstanceSource):
    """
    An instance source that samples instances from other instance sources.
    This can be used to adjust the effective size of a source.

    The instance budget is split across sources in proportion to their sizes. Each source's
    share is rounded down to a multiple of ``max_sequence_length // sequence_length``. When a
    source's share exceeds its size, the source is repeated in full before the remainder is
    taken.

    .. seealso::
        - :class:`SamplingTokenSource`
        - :class:`SamplingDocumentSource`

    :param sources: The sources to sample instances from. All must share the same
        ``sequence_length`` and ``max_sequence_length``.
    :param max_tokens: The maximum number of tokens to sample. Alternatively you can specify ``max_instances``.
    :param max_instances: The maximum number of instances to sample. Mutually exclusive with ``max_tokens``.
    :param seed: An optional seed for sampling. If ``None``, the first ``N_s`` instances are
        taken from each source where ``N_s`` is proportional to the size of the source.
        Otherwise the source at position ``i`` is sampled with seed ``seed + i``.

    :raises OLMoConfigurationError: If no sources are given, if the sampling limits are
        missing, conflicting or non-positive, if the sources disagree on sequence lengths,
        or if the sources contain no instances at all.

    .. warning::
        It's recommended to set a seed to ensure that the distribution of instances in
        child sources is preserved.
    """

    Config = SamplingInstanceSourceConfig

    DISPLAY_ICON = "\uedec"

    def __init__(
        self,
        *sources: InstanceSource,
        max_tokens: Optional[int] = None,
        max_instances: Optional[int] = None,
        work_dir: PathOrStr,
        seed: Optional[int] = SEED_NOT_SET,
        label: Optional[str] = None,
    ):
        if not sources:
            raise OLMoConfigurationError("At least one source must be provided.")

        sequence_length = sources[0].sequence_length
        max_sequence_length = sources[0].max_sequence_length

        if (max_tokens is None) == (max_instances is None):
            raise OLMoConfigurationError(
                "Either max_tokens or max_instances must be set, but not both."
            )
        if max_tokens is not None:
            if max_tokens <= 0:
                raise OLMoConfigurationError(f"'max_tokens' must be positive, got {max_tokens}.")
            # Truncates, so a budget smaller than one sequence yields zero instances.
            max_instances = max_tokens // sequence_length
        elif max_instances is not None and max_instances <= 0:
            raise OLMoConfigurationError(f"'max_instances' must be positive, got {max_instances}.")
        assert max_instances is not None

        for source in sources:
            if source.sequence_length != sequence_length:
                raise OLMoConfigurationError("All sources must have the same sequence length.")
            if source.max_sequence_length != max_sequence_length:
                raise OLMoConfigurationError("All sources must have the same max sequence length.")

        super().__init__(
            work_dir=work_dir,
            sequence_length=sequence_length,
            max_sequence_length=max_sequence_length,
            label=label,
        )
        self._og_sources = sources
        self._sources = sources
        self._max_instances = max_instances
        self._seed = resolve_seed(seed)
        self._dtype = np.uint32

        if self.seed is None:
            warnings.warn(
                "No seed provided for SamplingInstanceSource. "
                "It's recommended to set a seed to ensure that the distribution of instances in "
                "child sources are preserved."
            )

        # Must be set before `self.fingerprint` is accessed below, since the fingerprint hashes it.
        self._source_sample_sizes = self._compute_source_sample_sizes(max_instances)

        # Sample indices from each source and persist them to the work dir.
        source_sample_paths: List[PathOrStr] = []
        for position, (source, sample_size) in enumerate(
            zip(self.sources, self.source_sample_sizes)
        ):
            # The position is part of the file name so that a source passed more than once
            # (same fingerprint, but a different seed per position) gets its own index file
            # instead of overwriting the previous one.
            source_sample_path = (
                self.work_dir
                / f"{self.fingerprint}-{position}-{source.fingerprint}-indices.npy"
            )
            source_sample_paths.append(source_sample_path)
            if self.fs_local_rank == 0:
                write_array_to_disk(
                    self._build_source_sample_indices(source, sample_size, position),
                    source_sample_path,
                )
        self._source_sample_paths = tuple(source_sample_paths)
        # Other ranks must wait until the local rank 0 has written the index files.
        dist_utils.barrier()

    def _compute_source_sample_sizes(self, max_instances: int) -> Tuple[int, ...]:
        """Split ``max_instances`` across sources proportionally to size, in whole chunks."""
        total_instances = sum(len(source) for source in self._sources)
        if total_instances == 0:
            raise OLMoConfigurationError("Cannot sample from sources that contain no instances.")
        chunk_size = self.max_sequence_length // self.sequence_length
        sample_sizes: List[int] = []
        for source in self._sources:
            # We want `len(source) / total_instances ~= sample_size / max_instances`,
            # so `sample_size = max_instances * (len(source) / total_instances)`.
            sample_size = int(max_instances * (len(source) / total_instances))
            # Round down to a multiple of chunk_size.
            sample_sizes.append(chunk_size * (sample_size // chunk_size))
        return tuple(sample_sizes)

    def _build_source_sample_indices(
        self, source: InstanceSource, sample_size: int, position: int
    ) -> np.ndarray:
        """Return ``sample_size`` indices into ``source``, tiling the whole source if needed."""
        if sample_size == 0:
            # Also covers empty sources, for which `sample_size // len(source)` would divide by zero.
            return np.zeros(0, dtype=self._dtype)
        n_repetitions, remaining_sample_size = divmod(sample_size, len(source))
        source_indices = build_global_indices(
            len(source),
            sequence_length=self.sequence_length,
            max_sequence_length=self.max_sequence_length,
            seed=None if self.seed is None else self.seed + position,
            dtype=self._dtype,
        )
        return np.concatenate(
            [np.tile(source_indices, n_repetitions), source_indices[:remaining_sample_size]]
        )

    @property
    def sources(self) -> Tuple[InstanceSource, ...]:
        return self._sources

    @property
    def max_instances(self) -> int:
        return self._max_instances

    @property
    def seed(self) -> Optional[int]:
        return self._seed

    @property
    def source_sample_sizes(self) -> Tuple[int, ...]:
        """Number of instances sampled from each source, in source order."""
        return self._source_sample_sizes

    @ft.cached_property
    def num_instances(self) -> int:
        return sum(self.source_sample_sizes)

    @ft.cached_property
    def fingerprint(self) -> str:
        """SHA-256 over the class name, seed, and each (source fingerprint, sample size in chunks)."""
        sha256_hash = hashlib.sha256()
        sha256_hash.update((f"class={self.__class__.__name__},{self.seed=},").encode())
        chunk_size = self.max_sequence_length // self.sequence_length
        for source, sample_size in zip(self.sources, self.source_sample_sizes):
            sample_size_chunk_size_ratio = sample_size // chunk_size
            sha256_hash.update(
                f"source={source.fingerprint},{sample_size_chunk_size_ratio=}".encode()
            )
        return sha256_hash.hexdigest()

    def children(self):
        return self._og_sources

    def __len__(self) -> int:
        return self.num_instances

    def __getitem__(self, idx: int) -> Instance:
        """Map a global index to (source, position in that source's sample) and fetch it."""
        idx = self.validate_index(idx)
        source_start_offset = 0
        for source, source_sample_size, source_sample_indices_path in zip(
            self.sources, self.source_sample_sizes, self._source_sample_paths
        ):
            source_end_offset = source_start_offset + source_sample_size
            if source_start_offset <= idx < source_end_offset:
                offset_in_sample = idx - source_start_offset
                idx_in_source = load_array_slice(
                    source_sample_indices_path,
                    offset_in_sample,
                    offset_in_sample + 1,
                    self._dtype,
                )[0]
                return source[int(idx_in_source)]
            source_start_offset = source_end_offset
        raise IndexError(f"{idx} is out of bounds for source of size {len(self)}")