import functools as ft
import hashlib
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generator, List, Optional, Sequence, Tuple, TypedDict
from typing_extensions import NotRequired
import olmo_core.io as io
from olmo_core.aliases import PathOrStr
from olmo_core.config import Config
from olmo_core.exceptions import OLMoConfigurationError
from .source_abc import SourceABC
from .utils import SEED_NOT_SET, resolve_seed
if TYPE_CHECKING:
from .sampling_instance_source import (
SamplingInstanceSource,
SamplingInstanceSourceConfig,
)
from .sliced_instance_source import SlicedInstanceSource
[docs]
class Instance(TypedDict):
"""
An instance is just a dictionary that should include ``input_ids`` and optionally a
corresponding ``label_mask``.
"""
input_ids: Sequence[int]
"""The token IDs for this instance."""
label_mask: NotRequired[Sequence[bool]]
"""An optional mask indicating which tokens should contribute to the loss."""
class InstanceSource(SourceABC):
    """
    An abstract base class for a source of instances, usually consumed by a :class:`ComposableDataLoader`.
    It essentially represents an array of instances, where each instance is a sequence of
    ``sequence_length`` tokens.

    :param work_dir: The working directory for this source. Must be a local path, not a URL.
    :param sequence_length: The length of each sequence (instance) to produce. Must be positive.
    :param max_sequence_length: For sources that support this. If you intend to increase the sequence
        length in the middle of an epoch, you should set this to the maximum sequence length that you'll
        train on to guarantee that you can restart the run with the same data order after changing sequence length.
        Must be a positive multiple of ``sequence_length``; defaults to ``sequence_length``.
        Care needs to be taken when implementing this in a subclass to ensure that the exact same tokens
        will be produced when ``sequence_length`` is changed but ``max_sequence_length`` is fixed.
    :param label: An optional label, forwarded to :class:`SourceABC`.

    :raises OLMoConfigurationError: If ``work_dir`` is a URL, or ``max_sequence_length`` is smaller
        than / not a multiple of ``sequence_length``.
    """

    def __init__(
        self,
        *,
        work_dir: PathOrStr,
        sequence_length: int,
        max_sequence_length: Optional[int] = None,
        label: Optional[str] = None,
    ):
        super().__init__(work_dir=work_dir, label=label)
        if io.is_url(work_dir):
            raise OLMoConfigurationError(
                f"'work_dir' should be a local path, not a URL ('{work_dir}')."
            )
        assert sequence_length > 0
        if max_sequence_length is not None:
            assert max_sequence_length > 0
            if sequence_length > max_sequence_length:
                raise OLMoConfigurationError(
                    "'sequence_length' cannot be greater than 'max_sequence_length'."
                )
            if max_sequence_length % sequence_length != 0:
                raise OLMoConfigurationError(
                    "'max_sequence_length' must be a multiple of 'sequence_length'."
                )
        self._sequence_length = sequence_length
        self._max_sequence_length = max_sequence_length or sequence_length

    @property
    def sequence_length(self) -> int:
        """The sequence length of each instance that this source will produce."""
        return self._sequence_length

    @property
    def max_sequence_length(self) -> int:
        """
        Typically the same as ``sequence_length`` though in some cases it can be greater, such
        as when the sequence length will be increased in the middle of an epoch.
        """
        return self._max_sequence_length

    @property
    def num_tokens(self) -> int:
        """The total number of tokens in this source: ``len(self) * sequence_length``."""
        return len(self) * self.sequence_length

    @abstractmethod
    def __len__(self) -> int:
        """The number of instances available from this source."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, idx: int) -> Instance:
        """Get an instance by index."""
        raise NotImplementedError

    def __iter__(self) -> Generator[Instance, None, None]:
        """Iterate over all instances in the source."""
        for i in range(len(self)):
            yield self[i]

    def validate_index(self, idx: int) -> int:
        """
        Normalize and bounds-check an instance index.

        Negative indices count from the end, as with Python sequences.

        :param idx: The index to check. It is converted with ``int()`` first.
        :returns: The equivalent non-negative index.
        :raises IndexError: If the index is out of range.
        """
        requested = int(idx)
        size = len(self)  # computed once; may be non-trivial for composite sources
        idx = requested + size if requested < 0 else requested
        if not (0 <= idx < size):
            # Report the index the caller passed, not the wrapped-around value.
            raise IndexError(
                f"Index {requested} is out of bounds for source {self} with {size:,d} instances."
            )
        return idx

    def __add__(self, other: "InstanceSource") -> "ConcatenatedInstanceSource":
        """Add two instance sources together into a :class:`ConcatenatedInstanceSource`."""
        if isinstance(other, InstanceSource):
            return ConcatenatedInstanceSource(self, other, work_dir=self.common_work_dir)
        else:
            raise TypeError(f"Cannot add {type(self)} with {type(other)}.")

    def __mul__(self, factor: float) -> "SamplingInstanceSource":
        """Re-size this source by a given factor by sampling instances from it."""
        if isinstance(factor, (float, int)):
            return self.resize(factor)
        else:
            raise TypeError(f"Cannot multiply {type(self)} with {type(factor)}.")

    def sample(
        self,
        *,
        max_tokens: Optional[int] = None,
        max_instances: Optional[int] = None,
        seed: Optional[int] = SEED_NOT_SET,
    ) -> "SamplingInstanceSource":
        """
        Sample instances from this source.

        .. seealso::
            - :meth:`resize()`
            - :meth:`split()`

        :param max_tokens: The maximum number of tokens to sample from this source.
            Mutually exclusive with ``max_instances``.
        :param max_instances: The maximum number of instances to sample from this source.
            Mutually exclusive with ``max_tokens``.
        :param seed: A random seed for sampling. If ``None``, no shuffling is done and instances
            are taken in order.
        """
        from .sampling_instance_source import SamplingInstanceSource

        return SamplingInstanceSource(
            self,
            max_tokens=max_tokens,
            max_instances=max_instances,
            seed=resolve_seed(seed),
            work_dir=self.common_work_dir,
        )

    def resize(self, factor: float, seed: Optional[int] = SEED_NOT_SET) -> "SamplingInstanceSource":
        """
        Re-size this source by a given factor by sampling instances from it.

        .. seealso::
            - :meth:`sample()`
            - :meth:`split()`

        :param factor: The factor by which to resize this source. Must be positive.
        :param seed: A random seed for sampling.
        """
        assert factor > 0
        return self.sample(max_tokens=int(self.num_tokens * factor), seed=seed)

    def split(
        self, ratio: float, seed: Optional[int] = None
    ) -> Tuple["SlicedInstanceSource", "SlicedInstanceSource"]:
        """
        Split this source into two disjoint sources according to the given ratio.

        The split point is rounded down to a whole number of ``max_sequence_length``-token
        chunks, then expressed as an index in units of ``sequence_length`` instances.

        :param ratio: The ratio of the first split to original source. E.g., ``0.8`` means
            the first split will have 80% of the instances and the second split will have 20%.
            Must be strictly between 0 and 1.
        :param seed: A seed to use to randomize the split. It is passed through
            :func:`resolve_seed` before use, as in :meth:`sample`.
        """
        from .sliced_instance_source import SlicedInstanceSource

        assert 0 < ratio < 1
        # Number of whole max-length chunks in the first split, converted to instances.
        split_idx = int(
            ((ratio * self.num_tokens) // self.max_sequence_length)
            * (self.max_sequence_length // self.sequence_length)
        )
        # Resolve sentinel seeds (e.g. ``SEED_NOT_SET`` from ``random_split``) here, like
        # ``sample`` does, rather than forwarding them unresolved.
        seed = resolve_seed(seed)
        return (
            SlicedInstanceSource(
                self, slice(0, split_idx), seed=seed, work_dir=self.common_work_dir
            ),
            SlicedInstanceSource(
                self, slice(split_idx, None), seed=seed, work_dir=self.common_work_dir
            ),
        )

    def random_split(
        self, ratio: float, seed: int = SEED_NOT_SET
    ) -> Tuple["SlicedInstanceSource", "SlicedInstanceSource"]:
        """
        Like :meth:`split()` but always a random split.
        """
        return self.split(ratio, seed=seed)

    def visualize(self, icons: bool = True):
        """
        Print a visualization of this source and its children, recursively.

        :param icons: Whether to use icons in the visualization.

        .. important::
            Some icons used in the visualization require a Nerd Font to render properly.
        """
        from .visualize import visualize_source

        visualize_source(self, icons=icons)
@dataclass
class InstanceSourceConfig(Config):
    """
    Abstract base for configs that describe how to build an :class:`InstanceSource`.

    Configs support ``+`` (concatenation) and ``*`` (resizing by sampling), and can be
    sampled or split, producing new configs rather than built sources.
    """

    @abstractmethod
    def build(self, work_dir: PathOrStr) -> InstanceSource:
        """
        Build the :class:`InstanceSource` this config describes.

        :param work_dir: The working directory to build the source in.
        """
        raise NotImplementedError

    def __add__(self, other: "InstanceSourceConfig") -> "ConcatenatedInstanceSourceConfig":
        """Join this config with ``other`` into a :class:`ConcatenatedInstanceSourceConfig`."""
        if not isinstance(other, InstanceSourceConfig):
            raise TypeError(f"Cannot add {type(self)} with {type(other)}.")
        return ConcatenatedInstanceSourceConfig(sources=[self, other])

    def __mul__(self, factor: float) -> "SamplingInstanceSourceConfig":
        """Scale this config by ``factor``; shorthand for :meth:`resize`."""
        if not isinstance(factor, (float, int)):
            raise TypeError(f"Cannot multiply {type(self)} with {type(factor)}.")
        return self.resize(factor)

    def sample(
        self,
        *,
        max_tokens: Optional[int] = None,
        max_instances: Optional[int] = None,
        seed: Optional[int] = SEED_NOT_SET,
    ) -> "SamplingInstanceSourceConfig":
        """
        Wrap this config in a :class:`SamplingInstanceSourceConfig` that samples from it.

        :param max_tokens: Upper bound on the number of tokens to sample.
            Mutually exclusive with ``max_instances``.
        :param max_instances: Upper bound on the number of instances to sample.
            Mutually exclusive with ``max_tokens``.
        :param seed: Random seed for sampling (passed through :func:`resolve_seed`). If ``None``,
            no shuffling is done and instances are taken in order.
        """
        from .sampling_instance_source import SamplingInstanceSourceConfig

        resolved_seed = resolve_seed(seed)
        return SamplingInstanceSourceConfig(
            sources=[self],
            max_tokens=max_tokens,
            max_instances=max_instances,
            seed=resolved_seed,
        )

    def resize(
        self, factor: float, seed: Optional[int] = SEED_NOT_SET
    ) -> "SamplingInstanceSourceConfig":
        """
        Wrap this config in a :class:`SamplingInstanceSourceConfig` scaled by ``factor``.

        :param factor: Positive scaling factor for the size of this source.
        :param seed: Random seed for sampling (passed through :func:`resolve_seed`).
        """
        from .sampling_instance_source import SamplingInstanceSourceConfig

        assert factor > 0
        resolved_seed = resolve_seed(seed)
        return SamplingInstanceSourceConfig(sources=[self], factor=factor, seed=resolved_seed)

    def split(
        self, ratio: float, seed: Optional[int] = None
    ) -> Tuple["SplitInstanceSourceConfig", "SplitInstanceSourceConfig"]:
        """
        Describe two disjoint halves of this source, divided according to ``ratio``.

        :param ratio: Share of the original source assigned to the first part; e.g. ``0.8``
            puts 80% of the instances in the first part and 20% in the second.
        :param seed: Seed used to randomize the split; both parts share the same resolved seed.
        """
        shared_seed = resolve_seed(seed)
        first, second = (
            SplitInstanceSourceConfig(source=self, ratio=ratio, idx=part, seed=shared_seed)
            for part in (0, 1)
        )
        return first, second

    def random_split(
        self, ratio: float, seed: int = SEED_NOT_SET
    ) -> Tuple["SplitInstanceSourceConfig", "SplitInstanceSourceConfig"]:
        """
        Same as :meth:`split()`, except the seed defaults to ``SEED_NOT_SET`` so the split is randomized.
        """
        return self.split(ratio, seed=seed)
@dataclass
class SplitInstanceSourceConfig(InstanceSourceConfig):
    """
    A config that builds one side of a two-way split of another source.

    :param source: The config of the source to split.
    :param ratio: Share of the source assigned to the first split (``idx=0``); must be
        strictly between 0 and 1.
    :param idx: Which side to build: ``0`` for the first part, ``1`` for the remainder.
    :param seed: A seed used to randomize the split; resolved with :func:`resolve_seed`.
    """

    source: InstanceSourceConfig
    ratio: float
    idx: int
    seed: Optional[int] = None

    def __post_init__(self):
        assert 0 < self.ratio < 1
        assert self.idx in (0, 1)
        self.seed = resolve_seed(self.seed)

    def build(self, work_dir: PathOrStr) -> InstanceSource:
        """
        Build the underlying source and return the requested side of the split.

        The split point uses the same formula as :meth:`InstanceSource.split`: it is rounded
        down to a whole number of ``max_sequence_length``-token chunks and then expressed in
        ``sequence_length`` instances. Splitting a built source and splitting its config
        therefore place the boundary at the same index.

        :param work_dir: The working directory for the built sources.
        :raises ValueError: If ``idx`` is not ``0`` or ``1``.
        """
        from .sliced_instance_source import SlicedInstanceSource

        source = self.source.build(work_dir)
        # Previously ``int(self.ratio * len(source))``, which ignored ``max_sequence_length``
        # chunking and could disagree with ``InstanceSource.split``.
        split_idx = int(
            ((self.ratio * source.num_tokens) // source.max_sequence_length)
            * (source.max_sequence_length // source.sequence_length)
        )
        seed = resolve_seed(self.seed)
        if self.idx == 0:
            return SlicedInstanceSource(source, slice(0, split_idx), seed=seed, work_dir=work_dir)
        elif self.idx == 1:
            return SlicedInstanceSource(
                source, slice(split_idx, None), seed=seed, work_dir=work_dir
            )
        else:
            raise ValueError(f"Invalid split index: {self.idx}")
@dataclass
class ConcatenatedInstanceSourceConfig(InstanceSourceConfig):
    """
    Config for a :class:`ConcatenatedInstanceSource`.

    :param sources: Configs of the sources to join end-to-end, in list order.
    """

    sources: List[InstanceSourceConfig]

    def build(self, work_dir: PathOrStr) -> "ConcatenatedInstanceSource":
        """
        Build every child config in order and concatenate the results.

        :param work_dir: The working directory shared by the children and the concatenation.
        """
        built_children = [child.build(work_dir=work_dir) for child in self.sources]
        return ConcatenatedInstanceSource(*built_children, work_dir=work_dir)
class ConcatenatedInstanceSource(InstanceSource):
    """
    An instance source that concatenates multiple instance sources together end-to-end.

    Children that are themselves :class:`ConcatenatedInstanceSource` objects are flattened,
    so :attr:`sources` holds a single level of non-concatenated sources.

    :param sources: The sources to concatenate, in order. At least one is required, and all
        must share the same ``sequence_length`` and ``max_sequence_length``.
    :param work_dir: The working directory, passed to :class:`InstanceSource`.
    :param label: An optional label, passed to :class:`InstanceSource`.

    :raises OLMoConfigurationError: If no sources are given or their sequence lengths differ.
    """

    Config = ConcatenatedInstanceSourceConfig

    DISPLAY_ICON = "\uf51e"

    def __init__(
        self,
        *sources: InstanceSource,
        work_dir: PathOrStr,
        label: Optional[str] = None,
    ):
        if not sources:
            raise OLMoConfigurationError("At least one source must be provided.")

        reference = sources[0]
        expected_len = reference.sequence_length
        expected_max_len = reference.max_sequence_length
        # Check both lengths per source, in order, so the first offending source is reported.
        for candidate in sources:
            if candidate.sequence_length != expected_len:
                raise OLMoConfigurationError("All sources must have the same sequence length.")
            if candidate.max_sequence_length != expected_max_len:
                raise OLMoConfigurationError("All sources must have the same max sequence length.")

        super().__init__(
            work_dir=work_dir,
            sequence_length=expected_len,
            max_sequence_length=expected_max_len,
            label=label,
        )

        # Flatten one level: nested concatenations contribute their own (already flat) sources.
        self._sources = tuple(
            leaf
            for candidate in sources
            for leaf in (
                candidate.sources
                if isinstance(candidate, ConcatenatedInstanceSource)
                else (candidate,)
            )
        )

    @property
    def sources(self) -> Tuple[InstanceSource, ...]:
        """The flattened child sources, in concatenation order."""
        return self._sources

    @ft.cached_property
    def fingerprint(self) -> str:
        """SHA-256 hex digest over the class name and each child's fingerprint, in order."""
        digest = hashlib.sha256(f"class={self.__class__.__name__},".encode())
        for child in self.sources:
            digest.update(f"source={child.fingerprint},".encode())
        return digest.hexdigest()

    def __len__(self) -> int:
        """The total number of instances across all child sources."""
        return sum(map(len, self.sources))

    def __getitem__(self, idx: int) -> Instance:
        """
        Return the instance at ``idx`` of the concatenation.

        The index is normalized via :meth:`validate_index`, then mapped to the child that
        contains it by walking a running offset over the children's lengths.
        """
        idx = self.validate_index(idx)
        offset = 0
        for child in self.sources:
            child_len = len(child)
            # ``idx >= offset`` always holds here because earlier children were skipped.
            if idx < offset + child_len:
                return child[idx - offset]
            offset += child_len
        raise IndexError(f"{idx} is out of bounds for source of size {len(self)}")

    def children(self):
        """The child sources of this source (same as :attr:`sources`)."""
        return self.sources