Source code for olmo_core.data.composable.token_source

import functools as ft
import hashlib
import typing
from abc import abstractmethod
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    ClassVar,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypedDict,
    Union,
)

import numpy as np
from typing_extensions import NotRequired

from olmo_core.aliases import PathOrStr
from olmo_core.config import Config

from ..tokenizer import TokenizerConfig
from .source_abc import SourceABC
from .utils import SEED_NOT_SET, as_ndarray, resolve_seed

if TYPE_CHECKING:
    from .sampling_document_source import (
        SamplingDocumentSource,
        SamplingDocumentSourceConfig,
    )
    from .sampling_token_source import SamplingTokenSource, SamplingTokenSourceConfig
    from .sliced_token_source import SlicedTokenSource


[docs] class TokenRange(TypedDict): """ A token range is just a dictionary that should include ``input_ids`` of the range and optionally a corresponding ``label_mask``. """ input_ids: Sequence[int] """The token IDs for the range.""" label_mask: NotRequired[Sequence[bool]] """An optional mask indicating which tokens should contribute to the loss."""
[docs] class TokenSource(SourceABC): """ An abstract base class for a source of tokens, usually consumed by an :class:`InstanceSource`. It essentially represents an array of tokens. At a minimum, a :class:`TokenSource` must implement the methods/properties (1) :meth:`num_tokens`, (2) :meth:`get_token_range`, (3) :meth:`fingerprint`, and (4) :meth:`children`. """ DISPLAY_ICON: ClassVar[str] = "\ueb7e" # Nerd Font icon for visualizations
[docs] def __len__(self) -> int: """The number of tokens available from this source, same as ``self.num_tokens``.""" return self.num_tokens
[docs] @abstractmethod def get_token_range(self, start_idx: int, end_idx: int) -> TokenRange: """ Get a range of contiguous tokens starting from ``start_idx`` (0-based, inclusive) to ``end_idx`` (exclusive). Since a :class:`TokenSource` isn't necessarily aware of document boundaries (see :class:`DocumentSource`), the token range could start in the middle of a document and span multiple documents. It's up to the consumers of a token source (e.g. an :class:`InstanceSource`) to get ranges that make sense for their use case. """ raise NotImplementedError
[docs] def __getitem__(self, key: Union[int, slice]) -> TokenRange: """ Get a range of tokens using either an integer index (for a singular token range) or a slice. """ if isinstance(key, slice): start_idx = key.start if key.start is not None else 0 end_idx = key.stop if key.stop is not None else self.num_tokens step = key.step if key.step is not None else 1 token_rng = self.get_token_range(start_idx, end_idx) out: TokenRange = {"input_ids": token_rng["input_ids"][::step]} if "label_mask" in token_rng: out["label_mask"] = token_rng["label_mask"][::step] return out else: if key < 0: key = self.num_tokens + key return self.get_token_range(key, key + 1)
def validate_indices(self, start_idx: int, end_idx: int) -> Tuple[int, int]: start_idx, end_idx = int(start_idx), int(end_idx) if start_idx < 0: start_idx = self.num_tokens + start_idx if end_idx < 0: end_idx = self.num_tokens + end_idx if end_idx == start_idx: raise ValueError( f"Invalid token range {start_idx=}{end_idx=}, ranges cannot be empty." ) if end_idx <= start_idx: raise ValueError(f"Invalid token range {start_idx=}{end_idx=}.") if start_idx >= self.num_tokens or end_idx > self.num_tokens: raise IndexError( f"Token range {start_idx=}{end_idx=} is out of bounds " f"for source {self} with {self.num_tokens:,d} tokens." ) return start_idx, end_idx
[docs] def __add__(self, other: "TokenSource") -> "ConcatenatedTokenSource": """ Add two token sources together into a :class:`ConcatenatedTokenSource` or :class:`ConcatenatedDocumentSource` depending on the type of ``self`` and ``other``. """ if isinstance(self, DocumentSource) and isinstance(other, DocumentSource): return ConcatenatedDocumentSource(self, other, work_dir=self.common_work_dir) elif isinstance(other, TokenSource): return ConcatenatedTokenSource(self, other, work_dir=self.common_work_dir) else: raise TypeError(f"Cannot add {type(self)} with {type(other)}.")
[docs] def __mul__(self, factor: float) -> "SamplingTokenSource": """Re-size this source by a given factor by sampling tokens from it.""" if isinstance(factor, (float, int)): return self.resize(factor) else: raise TypeError(f"Cannot multiply {type(self)} with {type(factor)}.")
[docs] def sample( self, *, max_tokens: int, seed: Optional[int] = SEED_NOT_SET, ) -> "SamplingTokenSource": """ Sample a contiguous chunk of tokens from this source. .. seealso:: :meth:`resize()` :param max_tokens: The maximum number of tokens to sample. :param seed: A seed to use to randomize the sampling. """ from .sampling_token_source import SamplingTokenSource return SamplingTokenSource( self, max_tokens=max_tokens, seed=seed, work_dir=self.common_work_dir, )
[docs] def resize(self, factor: float, seed: Optional[int] = SEED_NOT_SET) -> "SamplingTokenSource": """ Re-size this source by a given factor by sampling a contiguous chunk of tokens from it. .. seealso:: :meth:`sample()` :param factor: The factor to resize the source by. For example, ``0.5`` will create a source with half the number of tokens, and ``2.0`` will create a source with twice the number of tokens. :param seed: A seed to use to randomize the sampling. """ assert factor > 0 return self.sample(max_tokens=int(self.num_tokens * factor), seed=seed)
[docs] def split(self, ratio: float) -> Tuple["SlicedTokenSource", "SlicedTokenSource"]: """ Split this source into two disjoint sources according to the given ratio. :param ratio: The ratio of the first split to original source. E.g., ``0.8`` means the first split will have 80% of the tokens and the second split will have 20%. """ from .sliced_token_source import SlicedTokenSource assert 0 < ratio < 1 split_idx = int(ratio * self.num_tokens) return ( SlicedTokenSource(self, slice(0, split_idx), work_dir=self.common_work_dir), SlicedTokenSource(self, slice(split_idx, None), work_dir=self.common_work_dir), )
[docs] class InMemoryTokenSource(TokenSource): """ An in-memory implementation of a :class:`TokenSource`. Primarily meant for testing. """ DISPLAY_ICON = "\U000f035b" def __init__( self, tokens: Sequence[int], *, work_dir: PathOrStr, label_mask: Optional[Sequence[bool]] = None, label: Optional[str] = None, ): super().__init__(work_dir=work_dir, label=label) self._tokens = as_ndarray(tokens) self._label_mask = None if label_mask is None else as_ndarray(label_mask) if self._label_mask is not None: assert len(self._tokens) == len(self._label_mask) def __repr__(self) -> str: return f"{self.__class__.__name__}({self._tokens})" @ft.cached_property def fingerprint(self) -> str: sha256_hash = hashlib.sha256() sha256_hash.update((f"class={self.__class__.__name__},tokens=").encode()) sha256_hash.update(self._tokens.tobytes()) if self._label_mask is not None: sha256_hash.update(b"mask=") sha256_hash.update(self._label_mask.tobytes()) return sha256_hash.hexdigest() @property def num_tokens(self) -> int: return len(self._tokens)
[docs] def get_token_range(self, start_idx: int, end_idx: int) -> TokenRange: start_idx, end_idx = self.validate_indices(start_idx, end_idx) out: TokenRange = {"input_ids": typing.cast(Sequence[int], self._tokens[start_idx:end_idx])} if self._label_mask is not None: out["label_mask"] = typing.cast(Sequence[bool], self._label_mask[start_idx:end_idx]) return out
[docs] def children(self): return []
[docs] class DocumentSource(TokenSource): """ An abstract base class for a particular type of :class:`TokenSource` that's aware of document boundaries. This class has one additional abstract method: :meth:`get_document_offsets()`. """ DISPLAY_ICON = "\uf15c"
[docs] def sample_by_docs( self, *, max_tokens: int, seed: Optional[int] = SEED_NOT_SET, ) -> "SamplingDocumentSource": """ Sample documents from this source. .. seealso:: - :meth:`~TokenSource.sample()` - :meth:`resize_by_docs()` :param max_tokens: The maximum number of tokens to sample. :param seed: A seed to use to randomize the sampling. """ from .sampling_document_source import SamplingDocumentSource return SamplingDocumentSource( self, max_tokens=max_tokens, seed=seed, work_dir=self.common_work_dir, )
[docs] def resize_by_docs( self, factor: float, seed: Optional[int] = SEED_NOT_SET ) -> "SamplingDocumentSource": """ Re-size this source by a given factor by sampling documents from it. .. seealso:: - :meth:`~TokenSource.resize()` - :meth:`sample_by_docs()` :param factor: The factor to resize the source by. For example, ``0.5`` will create a source with half the number of tokens, and ``2.0`` will create a source with twice the number of tokens. :param seed: A seed to use to randomize the sampling. """ assert factor > 0 return self.sample_by_docs(max_tokens=int(self.num_tokens * factor), seed=seed)
[docs] @abstractmethod def get_document_offsets(self) -> Iterable[tuple[int, int]]: """Get the start (inclusive) and end (exclusive) token indices of each document, in order.""" raise NotImplementedError
[docs] class InMemoryDocumentSource(InMemoryTokenSource, DocumentSource): """ An in-memory implementation of a :class:`DocumentSource`. Primarily meant for testing. """ def __init__( self, tokens: Sequence[int], *, tokenizer: TokenizerConfig, work_dir: PathOrStr, label_mask: Optional[Sequence[bool]] = None, label: Optional[str] = None, ): super().__init__(tokens=tokens, work_dir=work_dir, label_mask=label_mask, label=label) self._tokenizer = tokenizer @property def eos_token_id(self) -> int: return self._tokenizer.eos_token_id @property def bos_token_id(self) -> Optional[int]: return self._tokenizer.bos_token_id @ft.cached_property def fingerprint(self) -> str: sha256_hash = hashlib.sha256() sha256_hash.update( ( f"class={self.__class__.__name__}," f"eos_token_id={self.eos_token_id}," f"bos_token_id={self.bos_token_id}," f"tokens=" ).encode() ) sha256_hash.update(self._tokens.tobytes()) if self._label_mask is not None: sha256_hash.update(b"mask=") sha256_hash.update(self._label_mask.tobytes()) return sha256_hash.hexdigest()
[docs] def children(self): return []
[docs] def get_document_offsets(self) -> Iterable[tuple[int, int]]: if self.bos_token_id is None: doc_boundaries = (self._tokens == self.eos_token_id).nonzero()[0] else: doc_boundaries = np.logical_and( self._tokens[:-1] == self.eos_token_id, self._tokens[1:] == self.bos_token_id ).nonzero()[0] start_idx = 0 for idx in doc_boundaries: end_idx = idx + 1 yield start_idx, end_idx start_idx = end_idx # To avoid unexpected results, we ALWAYS treat the end of the source as the end of # a document, even if it doesn't end with an EOS token ID. if start_idx != self.num_tokens: yield start_idx, self.num_tokens
[docs] @dataclass class TokenSourceConfig(Config): """A base config class for configuring and building a :class:`TokenSource`."""
[docs] @abstractmethod def build(self, work_dir: PathOrStr) -> List[TokenSource]: """Build the token source.""" raise NotImplementedError
[docs] def __add__(self, other: "TokenSourceConfig") -> "TokenSourceConfig": """ Add two token source config together into a :class:`ConcatenatedTokenSourceConfig` or :class:`ConcatenatedDocumentSourceConfig` depending on the type of ``self`` and ``other``. """ if isinstance(self, DocumentSourceConfig) and isinstance(other, DocumentSourceConfig): return ConcatenatedDocumentSourceConfig(sources=[self, other]) elif isinstance(other, TokenSourceConfig): return ConcatenatedTokenSourceConfig(sources=[self, other]) else: raise TypeError(f"Cannot add {type(self)} with {type(other)}.")
[docs] def __mul__(self, factor: float) -> "SamplingTokenSourceConfig": """Re-size this source by a given factor by sampling tokens from it.""" if isinstance(factor, (float, int)): return self.resize(factor) else: raise TypeError(f"Cannot multiply {type(self)} with {type(factor)}.")
[docs] def sample( self, *, max_tokens: int, seed: Optional[int] = SEED_NOT_SET, ) -> "SamplingTokenSourceConfig": """ Sample a contiguous chunk of tokens from this source. :param max_tokens: The maximum number of tokens to sample. :param seed: A seed to use to randomize the sampling. """ from .sampling_token_source import SamplingTokenSourceConfig return SamplingTokenSourceConfig( sources=[self], max_tokens=max_tokens, seed=resolve_seed(seed), )
[docs] def resize( self, factor: float, seed: Optional[int] = SEED_NOT_SET ) -> "SamplingTokenSourceConfig": """ Re-size this source by a given factor by sampling a contiguous chunk of tokens from it. :param factor: The factor to resize the source by. For example, ``0.5`` will create a source with half the number of tokens, and ``2.0`` will create a source with twice the number of tokens. :param seed: A seed to use to randomize the sampling. """ from .sampling_token_source import SamplingTokenSourceConfig assert factor > 0 return SamplingTokenSourceConfig( sources=[self], factor=factor, seed=resolve_seed(seed), )
[docs] def split(self, ratio: float) -> Tuple["SplitTokenSourceConfig", "SplitTokenSourceConfig"]: """ Split this source into two disjoint sources according to the given ratio. :param ratio: The ratio of the first split to original source. E.g., ``0.8`` means the first split will have 80% of the tokens and the second split will have 20%. """ return SplitTokenSourceConfig(source=self, ratio=ratio, idx=0), SplitTokenSourceConfig( source=self, ratio=ratio, idx=1 )
[docs] @dataclass class SplitTokenSourceConfig(TokenSourceConfig): """A base config class for configuring and building a split :class:`TokenSource`.""" source: TokenSourceConfig ratio: float idx: int def __post_init__(self): assert 0 < self.ratio < 1 assert self.idx in (0, 1)
[docs] def build(self, work_dir: PathOrStr) -> List["SlicedTokenSource"]: # type: ignore[override] from .sliced_token_source import SlicedTokenSource sources = self.source.build(work_dir) source = ( sources[0] if len(sources) == 1 else ConcatenatedTokenSource(*sources, work_dir=work_dir) ) split_idx = int(self.ratio * source.num_tokens) if self.idx == 0: return [SlicedTokenSource(source, slice(0, split_idx), work_dir=work_dir)] elif self.idx == 1: return [SlicedTokenSource(source, slice(split_idx, None), work_dir=work_dir)] else: raise ValueError(f"Invalid split index: {self.idx}")
[docs] @dataclass class ConcatenatedTokenSourceConfig(TokenSourceConfig): """A base config class for configuring and building a :class:`ConcatenatedTokenSource`.""" sources: List[TokenSourceConfig] label: Optional[str] = None
[docs] def build(self, work_dir: PathOrStr) -> List["ConcatenatedTokenSource"]: # type: ignore[override] sources = [ source for source_config in self.sources for source in source_config.build(work_dir) ] return [ ConcatenatedTokenSource( *sources, work_dir=work_dir, label=self.label, ) ]
[docs] class ConcatenatedTokenSource(TokenSource): """ A token source that can be created from concatenating multiple other token sources. """ Config = ConcatenatedTokenSourceConfig DISPLAY_ICON = "\uf51e" def __init__(self, *sources: TokenSource, work_dir: PathOrStr, label: Optional[str] = None): super().__init__(work_dir=work_dir, label=label) unraveled_sources: List[TokenSource] = [] for source in sources: if isinstance(source, ConcatenatedTokenSource): unraveled_sources.extend(source.sources) else: unraveled_sources.append(source) self._sources = tuple(unraveled_sources) def __repr__(self) -> str: return f"{self.__class__.__name__}{self.sources}" @property def sources(self) -> Tuple[TokenSource, ...]: return self._sources
[docs] def children(self): return self.sources
@ft.cached_property def fingerprint(self) -> str: sha256_hash = hashlib.sha256() sha256_hash.update((f"class={self.__class__.__name__},").encode()) for source in self.sources: sha256_hash.update(f"{source=},".encode()) return sha256_hash.hexdigest() @property def num_tokens(self) -> int: return sum(source.num_tokens for source in self.sources)
[docs] def get_token_range(self, start_idx: int, end_idx: int) -> TokenRange: start_idx, end_idx = self.validate_indices(start_idx, end_idx) token_chunks: List[np.ndarray] = [] mask_chunks: List[np.ndarray] = [] source_start_offset = 0 for source in self.sources: source_size = source.num_tokens source_end_offset = source_start_offset + source_size if source_start_offset <= start_idx < source_end_offset: token_rng = source.get_token_range( start_idx - source_start_offset, min(end_idx - source_start_offset, source_size) ) token_chunks.append(as_ndarray(token_rng["input_ids"])) if "label_mask" in token_rng: mask_chunks.append(as_ndarray(token_rng["label_mask"])) if end_idx - source_start_offset <= source_size: break else: start_idx = source_end_offset source_start_offset = source_end_offset else: raise IndexError(f"Failed to find tokens in range {start_idx=}{end_idx=}.") input_ids = np.concatenate(token_chunks) out: TokenRange = {"input_ids": typing.cast(Sequence[int], input_ids)} if mask_chunks: out["label_mask"] = typing.cast(Sequence[bool], np.concatenate(mask_chunks)) return out
[docs] @dataclass class DocumentSourceConfig(TokenSourceConfig): """A base config class for configuring and building a :class:`DocumentSource`."""
[docs] @abstractmethod def build(self, work_dir: PathOrStr) -> List[DocumentSource]: # type: ignore[override] """Build the document source.""" raise NotImplementedError
[docs] def sample_by_docs( self, *, max_tokens: int, seed: Optional[int] = SEED_NOT_SET, ) -> "SamplingDocumentSourceConfig": """ Sample documents from this source. .. seealso:: - :meth:`~TokenSourceConfig.sample()` - :meth:`resize_by_docs()` :param max_tokens: The maximum number of tokens to sample. :param seed: A seed to use to randomize the sampling. """ from .sampling_document_source import SamplingDocumentSourceConfig return SamplingDocumentSourceConfig( sources=[self], max_tokens=max_tokens, seed=resolve_seed(seed), )
[docs] def resize_by_docs( self, factor: float, seed: Optional[int] = SEED_NOT_SET, ) -> "SamplingDocumentSourceConfig": """ Re-size this source by a given factor by sampling documents from it. .. seealso:: - :meth:`~TokenSourceConfig.resize()` - :meth:`sample_by_docs()` :param factor: The factor to resize the source by. For example, ``0.5`` will create a source with half the number of tokens, and ``2.0`` will create a source with twice the number of tokens. :param seed: A seed to use to randomize the sampling. """ from .sampling_document_source import SamplingDocumentSourceConfig assert factor > 0 return SamplingDocumentSourceConfig( sources=[self], factor=factor, seed=resolve_seed(seed), )
[docs] @dataclass class ConcatenatedDocumentSourceConfig(DocumentSourceConfig): """A base config class for configuring and building a :class:`ConcatenatedDocumentSource`.""" sources: List[DocumentSourceConfig] label: Optional[str] = None
[docs] def build(self, work_dir: PathOrStr) -> List["ConcatenatedDocumentSource"]: # type: ignore[override] sources = [ source for source_config in self.sources for source in source_config.build(work_dir) ] return [ ConcatenatedDocumentSource( *sources, work_dir=work_dir, label=self.label, ) ]
[docs] class ConcatenatedDocumentSource(ConcatenatedTokenSource, DocumentSource): """ A document source that can be created from concatenating multiple other document sources. """ Config = ConcatenatedDocumentSourceConfig # type: ignore[assignment] def __init__(self, *sources: DocumentSource, work_dir: PathOrStr, label: Optional[str] = None): super().__init__(*sources, work_dir=work_dir, label=label) def __repr__(self) -> str: return f"{self.__class__.__name__}{self.sources}" @property def sources(self) -> Tuple[DocumentSource, ...]: return typing.cast(Tuple[DocumentSource, ...], self._sources)
[docs] def get_document_offsets(self) -> Iterable[tuple[int, int]]: start_offset = 0 for source in self.sources: source_size = source.num_tokens last_doc_end = 0 for doc_start, doc_end in source.get_document_offsets(): assert doc_start == last_doc_end # API assumes consecutive documents yield doc_start + start_offset, doc_end + start_offset last_doc_end = doc_end # To avoid unexpected results, we ALWAYS treat the end of a source file as the end of # a document, even if it doesn't end with an EOS token ID. This *should* always be the case # anyway, but just to be careful. if last_doc_end != source_size: yield last_doc_end + start_offset, source_size + start_offset start_offset += source_size
[docs] def children(self): return self.sources