Source code for olmo_core.data.composable.sampling_token_source

import dataclasses
import functools as ft
import hashlib
import logging
import typing
from collections import deque
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import numpy as np

from olmo_core.aliases import PathOrStr
from olmo_core.exceptions import OLMoConfigurationError

from ..utils import get_rng
from .token_source import TokenRange, TokenSource, TokenSourceConfig
from .utils import SEED_NOT_SET, as_ndarray, resolve_seed

log = logging.getLogger(__name__)


@dataclass
class SamplingTokenSourceConfig(TokenSourceConfig):
    """
    A config for building a :class:`SamplingTokenSource`.

    Exactly one of ``max_tokens`` or ``factor`` must be set. When ``factor`` is set, the number
    of tokens to sample is ``int(factor * total)``, where ``total`` is the combined number of
    tokens across all sources built from ``sources``.
    """

    # Configs for the sources to sample tokens from.
    sources: List[TokenSourceConfig]
    # Absolute number of tokens to sample (mutually exclusive with ``factor``).
    max_tokens: Optional[int] = None
    # Sample size relative to the total size of the built sources
    # (mutually exclusive with ``max_tokens``).
    factor: Optional[float] = None
    seed: Optional[int] = dataclasses.field(default_factory=lambda: resolve_seed(SEED_NOT_SET))
    label: Optional[str] = None

    def __post_init__(self):
        if (self.max_tokens is None) == (self.factor is None):
            raise OLMoConfigurationError("Exactly one of 'max_tokens' or 'factor' must be set.")

    def build(self, work_dir: PathOrStr) -> List["SamplingTokenSource"]:  # type: ignore[override]
        """
        Build all child sources and wrap them in a single :class:`SamplingTokenSource`.

        :param work_dir: The working directory passed to each child source and to the sampler.
        :returns: A one-element list containing the sampling source.
        :raises OLMoConfigurationError: If the resulting number of tokens to sample is not
            positive (e.g. a small ``factor`` that rounds down to 0 tokens).
        """
        sources = [s for source in self.sources for s in source.build(work_dir=work_dir)]
        max_tokens = self.max_tokens
        if max_tokens is None:
            assert self.factor is not None
            max_tokens = int(self.factor * sum(source.num_tokens for source in sources))
        # Report a configuration error here rather than failing later with an opaque
        # error inside ``SamplingTokenSource``.
        if max_tokens <= 0:
            raise OLMoConfigurationError(
                f"The number of tokens to sample must be positive, got {max_tokens} "
                f"(max_tokens={self.max_tokens}, factor={self.factor})."
            )
        return [
            SamplingTokenSource(
                *sources,
                max_tokens=max_tokens,
                seed=self.seed,
                work_dir=work_dir,
                label=self.label,
            )
        ]
class SamplingTokenSource(TokenSource):
    """
    A token source that samples contiguous chunks of tokens from other token sources.
    This can be used to adjust the effective size of a source.

    .. tip::
        Unlike :class:`SamplingDocumentSource`, this class doesn't take document boundaries into
        account when sampling, but is much faster to set up.

    :param sources: The sources to sample tokens from.
    :param max_tokens: The number of tokens to sample. The per-source sample sizes are rounded
        to the nearest integer, so the total can differ slightly from this value.
    :param seed: An optional seed for sampling. If ``None``, the first ``N_s`` tokens are taken
        from each source where ``N_s`` is proportional to the size of the source.

    .. warning::
        Generally you should prefer to use :class:`SamplingDocumentSource` with random sampling
        (a seed provided) to preserve the distribution of child sources. This is a quick and
        dirty alternative.
    """

    Config = SamplingTokenSourceConfig

    DISPLAY_ICON = "\uedec"

    def __init__(
        self,
        *sources: TokenSource,
        max_tokens: int,
        seed: Optional[int] = SEED_NOT_SET,
        work_dir: PathOrStr,
        label: Optional[str] = None,
    ):
        from .mixing_token_source import MixingTokenSource
        from .sliced_token_source import SlicedTokenSource

        if not sources:
            raise ValueError("At least one source must be provided.")
        # Explicit check instead of ``assert`` so validation isn't stripped under ``python -O``.
        if max_tokens <= 0:
            raise ValueError(f"'max_tokens' must be positive, got {max_tokens}.")
        total_tokens = sum(source.num_tokens for source in sources)
        if total_tokens <= 0:
            raise ValueError("Cannot sample from sources that contain no tokens.")

        super().__init__(work_dir=work_dir, label=label)

        # Determine how many tokens to sample from each source.
        # NOTE: We do our best to "unwind" child sources that are mixing or sampling sources
        # themselves in order to preserve the distribution of their child sources.
        # But this doesn't cover all edge cases, so it's generally better to do sampling/mixing
        # at the document level.
        frontier = deque(sources)
        seed = resolve_seed(seed)
        rng = None if seed is None else get_rng(seed)
        final_sources: List[TokenSource] = []
        while frontier:
            source = frontier.popleft()
            # Unwind any mixing token sources into their sampling token sources,
            # and sampling token sources into their children so that we sample directly from
            # each child in order to maintain the desired ratios.
            if isinstance(source, MixingTokenSource):
                frontier.extend(source.sampled_sources)
            elif isinstance(source, SamplingTokenSource):
                frontier.extend(source.sources)
            elif source.num_tokens == 0:
                # Nothing to sample from an empty source; skipping it also avoids a
                # division by zero when computing repetitions below.
                continue
            else:
                # Determine how many tokens to sample from source while keeping the same
                # ratios between sources. For example, suppose source A makes up 75% of the
                # `total_tokens` available across all sources. Then we want the number of tokens
                # we sample from A to make up 75% of `max_tokens`. In other words, we want
                # `len(source) / total_tokens ~= source_sample_size / max_tokens`,
                # so `source_sample_size = max_tokens * (source.num_tokens / total_tokens)`.
                source_sample_size = round(max_tokens * (len(source) / total_tokens))

                # Whole copies of the source first, then a contiguous slice for the remainder.
                n_repetitions = source_sample_size // source.num_tokens
                final_sources.extend([source] * n_repetitions)
                remaining_sample_size = source_sample_size % source.num_tokens
                if remaining_sample_size > 0:
                    # NOTE(review): assuming ``rng`` is a ``numpy.random.Generator``, ``integers``
                    # excludes ``high``, so the window ending exactly at the last token is never
                    # chosen. Left unchanged so existing seeds keep producing the same samples.
                    start_idx = (
                        0
                        if rng is None
                        else rng.integers(0, source.num_tokens - remaining_sample_size)
                    )
                    end_idx = start_idx + remaining_sample_size
                    final_sources.append(
                        SlicedTokenSource(source, slice(start_idx, end_idx), work_dir=self.work_dir)
                    )

        self._og_sources = sources
        self._sources = tuple(final_sources)
        self._seed = seed

    @property
    def sources(self) -> Tuple[TokenSource, ...]:
        """The flattened sources actually sampled from (repetitions and slices included)."""
        return self._sources

    @ft.cached_property
    def num_tokens(self) -> int:
        return sum(source.num_tokens for source in self.sources)

    @ft.cached_property
    def fingerprint(self) -> str:
        sha256_hash = hashlib.sha256()
        sha256_hash.update((f"class={self.__class__.__name__},").encode())
        for source in self.sources:
            sha256_hash.update(f"source={source.fingerprint},".encode())
        return sha256_hash.hexdigest()

    @property
    def seed(self) -> Optional[int]:
        return self._seed

    def get_token_range(self, start_idx: int, end_idx: int) -> TokenRange:
        """
        Get the tokens in ``[start_idx, end_idx)`` by concatenating the overlapping parts of the
        sampled sources, in order.

        If any of the overlapping sources returns a ``label_mask``, the output ``label_mask`` is
        kept aligned with ``input_ids``.
        """
        start_idx, end_idx = self.validate_indices(start_idx, end_idx)
        token_chunks: List[np.ndarray] = []
        # One entry per token chunk; ``None`` when that source returned no label mask.
        mask_chunks: List[Optional[np.ndarray]] = []
        source_start_offset = 0
        for source in self.sources:
            source_end_offset = source_start_offset + len(source)
            if source_start_offset <= start_idx < source_end_offset:
                token_rng = source.get_token_range(
                    start_idx - source_start_offset,
                    min(end_idx - source_start_offset, len(source)),
                )
                token_chunks.append(as_ndarray(token_rng["input_ids"]))
                mask_chunks.append(
                    as_ndarray(token_rng["label_mask"]) if "label_mask" in token_rng else None
                )
                if end_idx <= source_end_offset:
                    break
                # The requested range continues into the next source.
                start_idx = source_end_offset
            source_start_offset = source_end_offset

        input_ids = np.concatenate(token_chunks)
        out: TokenRange = {"input_ids": typing.cast(Sequence[int], input_ids)}
        if any(mask is not None for mask in mask_chunks):
            # Fill in chunks without a mask so the result has the same length as ``input_ids``.
            # NOTE(review): assumes a missing mask means every token is a valid label
            # (all True) — confirm against TokenSource's contract.
            label_mask = np.concatenate(
                [
                    mask if mask is not None else np.ones(len(tokens), dtype=bool)
                    for tokens, mask in zip(token_chunks, mask_chunks)
                ]
            )
            out["label_mask"] = typing.cast(Sequence[bool], label_mask)
        return out

    def children(self):
        """The original (un-flattened) sources passed to the constructor."""
        return self._og_sources