import dataclasses
import functools as ft
import hashlib
import logging
import typing
import warnings
from dataclasses import dataclass
from typing import Iterable, List, Optional, Sequence
import numpy as np
import olmo_core.distributed.utils as dist_utils
import olmo_core.io as io
from olmo_core.aliases import PathOrStr
from olmo_core.exceptions import OLMoConfigurationError
from ..utils import get_rng, load_array_slice, write_array_to_disk
from .token_source import (
ConcatenatedDocumentSource,
DocumentSource,
DocumentSourceConfig,
TokenRange,
)
from .utils import SEED_NOT_SET, as_ndarray, resolve_seed
log = logging.getLogger(__name__)
[docs]
@dataclass
class SamplingDocumentSourceConfig(DocumentSourceConfig):
    """
    Configuration used to construct a :class:`SamplingDocumentSource`.

    The size of the sample is given either directly through ``max_tokens`` or relative to the
    combined size of the built child sources through ``factor``. Exactly one of the two must
    be provided, otherwise an :class:`OLMoConfigurationError` is raised on construction.
    """

    # Configs of the child sources that documents are sampled from.
    sources: List[DocumentSourceConfig]
    # Absolute token budget for the sample (mutually exclusive with ``factor``).
    max_tokens: Optional[int] = None
    # Token budget expressed as a multiple of the total number of tokens in the built child
    # sources (mutually exclusive with ``max_tokens``).
    factor: Optional[float] = None
    # Sampling seed; defaults to whatever ``resolve_seed(SEED_NOT_SET)`` returns at
    # construction time.
    seed: Optional[int] = dataclasses.field(default_factory=lambda: resolve_seed(SEED_NOT_SET))
    # Optional label forwarded to the built source.
    label: Optional[str] = None

    def __post_init__(self):
        has_max_tokens = self.max_tokens is not None
        has_factor = self.factor is not None
        # Both set or both unset is invalid.
        if has_max_tokens == has_factor:
            raise OLMoConfigurationError("Exactly one of 'max_tokens' or 'factor' must be set.")

    def build(self, work_dir: PathOrStr) -> List["SamplingDocumentSource"]:  # type: ignore[override]
        """
        Build all child sources and wrap them in a single :class:`SamplingDocumentSource`.

        :param work_dir: Working directory passed to every child config's ``build()`` and to the
            resulting sampling source.

        :returns: A one-element list containing the sampling source.
        """
        # Each child config may build several sources; flatten them in order.
        built = []
        for source_config in self.sources:
            built.extend(source_config.build(work_dir=work_dir))

        if self.max_tokens is not None:
            token_budget = self.max_tokens
        else:
            assert self.factor is not None
            # Scale the combined size of the children; ``int()`` truncates toward zero.
            token_budget = int(self.factor * sum(src.num_tokens for src in built))

        sampler = SamplingDocumentSource(
            *built,
            max_tokens=token_budget,
            seed=self.seed,
            work_dir=work_dir,
            label=self.label,
        )
        return [sampler]
[docs]
class SamplingDocumentSource(DocumentSource):
    """
    A document source that samples documents from other document sources.
    This can be used to adjust the effective size of a source.

    .. seealso::
        - :class:`SamplingTokenSource`
        - :class:`SamplingInstanceSource`

    :param sources: The sources to sample documents from.
    :param max_tokens: The maximum number of tokens to sample. The resulting source will have
        at most this many tokens, but potentially less because only whole documents are sampled.
    :param seed: An optional seed for sampling documents. If ``None``, no shuffling is done and
        the first documents are taken up to ``max_tokens``.
    :param work_dir: Working directory; the sampled document index files are written here.
    :param label: Optional label passed through to the base class.

    :raises ValueError: If ``max_tokens`` is not positive or no sources are given.
    :raises RuntimeError: If no documents can be sampled (e.g. the source has no tokens, or
        ``max_tokens`` is too small to fit a single whole document).

    .. warning::
        It's recommended to set a seed to ensure that the distribution of documents in child
        sources is preserved.
    """

    Config = SamplingDocumentSourceConfig

    DISPLAY_ICON = "\uedec"

    def __init__(
        self,
        *sources: DocumentSource,
        max_tokens: int,
        seed: Optional[int] = SEED_NOT_SET,
        work_dir: PathOrStr,
        label: Optional[str] = None,
    ):
        # Explicit validation instead of ``assert`` so the check survives ``python -O``.
        if max_tokens <= 0:
            raise ValueError(f"'max_tokens' must be a positive integer, got {max_tokens}.")
        if not sources:
            raise ValueError("At least one source must be provided.")

        super().__init__(work_dir=work_dir, label=label)

        # Present multiple children as a single source so they can be sampled uniformly.
        source: DocumentSource
        if len(sources) > 1:
            source = ConcatenatedDocumentSource(*sources, work_dir=work_dir)
        else:
            source = sources[0]

        self._og_sources = sources
        self._source = source
        self._max_tokens = max_tokens
        self._seed = resolve_seed(seed)
        if self.seed is None:
            warnings.warn(
                "No seed provided for SamplingDocumentSource. "
                "It's recommended to set a seed to ensure that the distribution of documents in "
                "child sources are preserved."
            )

        # Sample tokens from the source.
        log.info(f"Sampling documents from {self.source}...")
        # File names are keyed by the fingerprint so that an existing sample for the same
        # (source, max_tokens, seed) combination is reused instead of recomputed.
        self._sampled_document_offsets_path = self.work_dir / f"{self.fingerprint}-doc-indices.npy"
        self._sampled_cu_document_lens_path = self.work_dir / f"{self.fingerprint}-doc-lens.npy"
        if (
            not self._sampled_document_offsets_path.is_file()
            or not self._sampled_cu_document_lens_path.is_file()
        ) and self.fs_local_rank == 0:
            # Collect original document indices as an (n_docs, 2) array of (start, end) offsets.
            document_offsets = np.fromiter(
                (idx for offsets in self.source.get_document_offsets() for idx in offsets),
                dtype=np.uint64,
            ).reshape(-1, 2)

            # Maybe shuffle OG doc indices.
            if self.seed is not None:
                rng = get_rng(self.seed)
                rng.shuffle(document_offsets, axis=0)

            # Find cumulative token counts, then repeat/truncate OG docs to get the target max
            # number of tokens.
            document_lengths = document_offsets[:, 1] - document_offsets[:, 0]
            cu_document_lengths = np.cumsum(document_lengths, dtype=np.uint64)
            total_tokens = int(cu_document_lengths[-1]) if cu_document_lengths.size > 0 else 0
            if total_tokens == 0:
                # Without this guard an empty source fails with an opaque IndexError and a
                # source of only empty documents fails with ZeroDivisionError below.
                raise RuntimeError(
                    f"Unable to sample {self.max_tokens} tokens from {self.source}: "
                    "the source contains no tokens"
                )

            # Whole passes over the source, plus a partial pass for the remainder.
            n_repetitions = max_tokens // total_tokens
            remaining_sample_size = max_tokens % total_tokens
            # The partial pass keeps only documents whose cumulative length still fits.
            sampled_document_offsets = np.take(
                document_offsets,
                (cu_document_lengths <= remaining_sample_size).nonzero()[0],
                axis=0,
            )
            if n_repetitions > 0:
                sampled_document_offsets = np.concatenate(
                    [
                        np.tile(document_offsets, (n_repetitions, 1)),
                        sampled_document_offsets,
                    ]
                )

            if sampled_document_offsets.shape[0] == 0:
                raise RuntimeError(f"Unable to sample {self.max_tokens} tokens from {self.source}")

            # Now get the cumulative lengths of the sampled documents, with a leading 0 so that
            # document ``i`` spans ``[cu[i], cu[i + 1])`` in the sampled token space.
            sampled_document_lengths = (
                sampled_document_offsets[:, 1] - sampled_document_offsets[:, 0]
            )
            sampled_cu_document_lengths = np.concatenate(
                [
                    np.array([0], dtype=np.uint64),
                    np.cumsum(sampled_document_lengths, dtype=np.uint64),
                ]
            )

            # Write to disk.
            write_array_to_disk(
                sampled_document_offsets.reshape(-1), self._sampled_document_offsets_path
            )
            write_array_to_disk(sampled_cu_document_lengths, self._sampled_cu_document_lens_path)

        # Reached by every rank, not only the one that wrote the files above.
        dist_utils.barrier()

    @property
    def source(self) -> DocumentSource:
        """The (possibly concatenated) source that documents are sampled from."""
        return self._source

    @property
    def max_tokens(self) -> int:
        """The token budget this source was created with."""
        return self._max_tokens

    @ft.cached_property
    def num_docs(self) -> int:
        """
        The number of sampled documents, derived from the size of the offsets file.

        The offsets file holds two ``uint64`` values (start, end) per document.
        NOTE(review): this assumes ``write_array_to_disk`` stores raw array bytes without a
        header, which is also what the ``np.memmap`` reads in this class rely on.
        """
        return (
            io.get_file_size(self._sampled_document_offsets_path)
            // np.dtype(np.uint64).itemsize
            // 2
        )

    @ft.cached_property
    def num_tokens(self) -> int:
        """The total number of sampled tokens, i.e. the last cumulative document length."""
        # The cumulative-lengths file has ``num_docs + 1`` entries because of the leading 0.
        return int(
            load_array_slice(
                self._sampled_cu_document_lens_path, self.num_docs, self.num_docs + 1, np.uint64
            )[0]
        )

    @property
    def seed(self) -> Optional[int]:
        """The resolved sampling seed, or ``None`` if documents are not shuffled."""
        return self._seed

    @ft.cached_property
    def fingerprint(self) -> str:
        """A SHA-256 hex digest over the class name, source fingerprint, token budget and seed."""
        sha256_hash = hashlib.sha256()
        sha256_hash.update(
            (
                f"class={self.__class__.__name__},"
                f"source={self.source.fingerprint},"
                f"max_tokens={self.max_tokens},"
                f"seed={self.seed},"
            ).encode()
        )
        return sha256_hash.hexdigest()

    def get_token_range(self, start_idx: int, end_idx: int) -> TokenRange:
        """
        Get the tokens in ``[start_idx, end_idx)`` of the sampled token space.

        :param start_idx: Index of the first token (inclusive), relative to this source.
        :param end_idx: Index past the last token (exclusive), relative to this source.

        :returns: A token range with ``"input_ids"`` and, when the wrapped source provides
            them, ``"label_mask"``.
        """
        start_idx, end_idx = self.validate_indices(start_idx, end_idx)

        # NOTE: we need to map this range to ranges of tokens in the original source.
        # We do that by mapping this range to a range of documents, then getting the right range
        # of tokens from within each document.

        # Load cumulative document lengths for the sampled documents.
        cu_doc_lens = np.memmap(self._sampled_cu_document_lens_path, mode="r", dtype=np.uint64)

        # Get the document indices (with respect to the local sample) that encompasses the token
        # range: document ``i`` overlaps iff ``cu[i + 1] > start_idx`` and ``cu[i] < end_idx``.
        doc_indices_in_sample = np.logical_and(
            (cu_doc_lens > start_idx)[1:], (cu_doc_lens[:-1] < end_idx)
        ).nonzero()[0]
        starting_doc, ending_doc = int(doc_indices_in_sample[0]), int(doc_indices_in_sample[-1])

        # Now load the corresponding offsets of the original documents.
        og_doc_offsets = load_array_slice(
            self._sampled_document_offsets_path, 2 * starting_doc, 2 * (ending_doc + 1), np.uint64
        ).reshape(-1, 2)

        # Finally, we iterate over the OG document offsets and load the corresponding ranges.
        document_input_ids: List[np.ndarray] = []
        document_label_masks: List[np.ndarray] = []
        tokens_remaining = end_idx - start_idx
        # ``tolist()`` yields plain Python ints, so no ``np.uint64`` scalars leak into the
        # index arithmetic below or into the wrapped source.
        for doc_idx, (doc_start, doc_end) in enumerate(og_doc_offsets.tolist()):
            if doc_idx == 0:
                # Skip the part of the first document that precedes ``start_idx``.
                doc_start += start_idx - int(cu_doc_lens[starting_doc])
            token_rng = self.source.get_token_range(
                doc_start, min(doc_end, doc_start + tokens_remaining)
            )
            document_input_ids.append(as_ndarray(token_rng["input_ids"]))
            if "label_mask" in token_rng:
                document_label_masks.append(as_ndarray(token_rng["label_mask"]))
            tokens_remaining -= document_input_ids[-1].size

        # Combine token IDs and maybe label masks for each document.
        input_ids = np.concatenate(document_input_ids)
        out: TokenRange = {"input_ids": typing.cast(Sequence[int], input_ids)}
        if document_label_masks:
            out["label_mask"] = typing.cast(Sequence[bool], np.concatenate(document_label_masks))
        return out

    def get_document_offsets(self) -> Iterable[tuple[int, int]]:
        """
        Yield ``(start, end)`` token offsets of each sampled document, relative to this source.
        """
        cu_doc_lens = np.memmap(self._sampled_cu_document_lens_path, mode="r", dtype=np.uint64)
        start_offset = 0
        # Skip the leading 0; each remaining entry is the end offset of one document.
        for cu_doc_len in cu_doc_lens[1:]:
            yield (start_offset, int(cu_doc_len))
            start_offset = int(cu_doc_len)

    def children(self):
        """Return the original sources passed to the constructor."""
        return self._og_sources