# Source code for olmo_core.data.composable.utils

import concurrent.futures
import warnings
from typing import Callable, List, Literal, Optional, Sequence, Type, TypeVar, Union

import numpy as np
import torch

import olmo_core.io as io
from olmo_core.aliases import PathOrStr
from olmo_core.exceptions import OLMoConfigurationError, OLMoEnvironmentError

from ..types import NumpyUIntTypes
from ..utils import get_rng


def _warmup_clients(paths: Sequence[PathOrStr]):
    # Maybe create client up front to work around a threading issue in boto.
    if any(str(p).startswith("s3://") for p in paths):
        io._get_s3_client("s3")

    if any(str(p).startswith("r2://") for p in paths):
        try:
            io._get_s3_client("r2")
        except OLMoEnvironmentError:
            # R2 might not be needed, so ignore this error. We will get an error
            # later if R2 is needed.
            pass

    if any(str(p).startswith("weka://") for p in paths):
        try:
            io._get_s3_client("weka")
        except OLMoEnvironmentError:
            # Weka might not be needed, so ignore this error. We will get an error
            # later if Weka is needed.
            pass


# Generic result type of the function that ``path_map()`` applies to each path.
T = TypeVar("T")


def path_map(
    func: Callable[[PathOrStr], T],
    paths: Sequence[PathOrStr],
    *,
    max_workers: Optional[int] = None,
    method: Literal["threads", "processes"] = "threads",
) -> List[T]:
    """
    Apply a function to every path and collect the results in input order.

    :param func: The function to call with each path.
    :param paths: The paths to process.
    :param max_workers: The number of workers threads/processes. Set to 0 to execute synchronously
        in the main thread/process. A single path (or none) is always handled synchronously.
    :param method: Whether to use multi-threading or multi-processing.

    :returns: The results, in the same order as :data:`paths`.

    :raises ValueError: If ``method`` is neither ``"threads"`` nor ``"processes"``.
    """
    # Nothing to parallelize: run inline.
    if max_workers == 0 or len(paths) <= 1:
        return list(map(func, paths))

    pool_cls: Union[
        Type[concurrent.futures.ThreadPoolExecutor],
        Type[concurrent.futures.ProcessPoolExecutor],
    ]
    if method == "processes":
        pool_cls = concurrent.futures.ProcessPoolExecutor
    elif method == "threads":
        # Clients are created before any worker thread starts.
        _warmup_clients(paths)
        pool_cls = concurrent.futures.ThreadPoolExecutor
    else:
        raise ValueError(method)

    # Leaving the ``with`` block waits for every submitted call to finish.
    with pool_cls(max_workers=max_workers) as pool:
        pending = [pool.submit(func, item) for item in paths]

    return [job.result() for job in pending]


def format_fname_from_fields(prefix: str, **fields) -> str:
    """
    Build an underscore-delimited name from a prefix and keyword fields.

    Fields are visited in sorted key order and each contributes ``f"{key}{value}"``.
    Fields whose value is ``None`` are left out.

    :param prefix: The leading component of the name.
    :param fields: The key/value pairs to append.

    :returns: The joined name.
    """
    suffixes = [f"{name}{fields[name]}" for name in sorted(fields) if fields[name] is not None]
    return "_".join([prefix, *suffixes])


def format_token_count(n: int) -> str:
    """
    Render a count compactly using a ``K``/``M``/``B``/``T`` suffix.

    The largest unit that ``n`` reaches is used and the scaled value is shown with one
    decimal place. Counts below 1,000 are returned as plain ``str(n)``.

    :param n: The count to format.

    :returns: The formatted count.
    """
    # Ordered from largest to smallest so the first match is the biggest applicable unit.
    units = (
        (1_000_000_000_000, "T"),
        (1_000_000_000, "B"),
        (1_000_000, "M"),
        (1_000, "K"),
    )
    for scale, suffix in units:
        if n >= scale:
            return f"{n / scale:.1f}{suffix}"
    return str(n)


def as_ndarray(array: Union[Sequence[int], Sequence[bool]]) -> np.ndarray:
    """
    Convert ``array`` to a NumPy array.

    A NumPy array is returned as-is, a torch tensor is moved to the CPU and converted,
    and anything else is passed to ``np.array``.
    """
    if isinstance(array, torch.Tensor):
        return array.cpu().numpy()
    if isinstance(array, np.ndarray):
        return array
    return np.array(array)


def as_tensor(array: Union[Sequence[int], Sequence[bool]]) -> torch.Tensor:
    """
    Convert ``array`` to a torch tensor.

    A tensor is returned as-is. A non-boolean NumPy array is cast to ``np.int_`` and
    turned into a ``torch.long`` CPU tensor. Boolean NumPy arrays and all other inputs
    are handed to ``torch.tensor`` directly, on the CPU.
    """
    if isinstance(array, torch.Tensor):
        return array
    # The short-circuit guarantees ``.dtype`` is only read from NumPy arrays.
    if not isinstance(array, np.ndarray) or array.dtype == np.bool_:
        return torch.tensor(array, device="cpu")
    return torch.tensor(array.astype(np.int_), dtype=torch.long, device="cpu")


def calculate_sample_sizes(
    source_sizes: Sequence[int],
    target_ratios: Sequence[float],
    max_repetition_factors: Sequence[float],
    target_size: Optional[int] = None,
    labels: Optional[Sequence[str]] = None,
    unit: str = "tokens",
) -> np.ndarray:
    """
    Calculate the number of items needed to sample from each source in order to match the target ratios.

    :param source_sizes: The number of items available in each source. All must be positive.
    :param target_ratios: The desired share of the mix for each source. All must be positive.
        If they don't sum to 1 they are normalized and a :class:`UserWarning` is issued.
    :param max_repetition_factors: For each source, the largest factor by which its items may be
        repeated (over-sampled). All must be at least 1.0.
    :param target_size: The desired total number of items across all sources. If ``None``, the sum
        of ``source_sizes`` is used and no error is raised when that total can't be reached.
    :param labels: Optional per-source names, used only in the warning and error messages.
    :param unit: The name of the item unit, used only in the error message.

    :returns: An array of ``np.uint64`` sample sizes, one per source.

    :raises OLMoConfigurationError: If ``target_size`` is given and can't be met with the given
        ratios and max repetition factors.
    :raises AssertionError: If the inputs differ in length or violate the constraints above.
    """
    assert len(source_sizes) == len(target_ratios) == len(max_repetition_factors)
    if labels is not None:
        assert len(labels) == len(source_sizes)

    ratios = np.array(target_ratios)
    sizes = np.array(source_sizes)
    max_repetition_factors_ = np.array(max_repetition_factors)

    assert (ratios > 0.0).all(), f"All ratios must be positive! Got {target_ratios}"
    assert (
        max_repetition_factors_ >= 1.0
    ).all(), f"All max repetition factors must be at least 1.0! Got {max_repetition_factors}"
    assert (sizes > 0).all(), f"All source sizes must be positive! Got {sizes}"

    # Only an explicitly requested target size is enforced (see the check near the end).
    strict = True
    if target_size is None:
        strict = False
        target_size = sizes.sum()

    # Normalize ratios.
    ratio_total = ratios.sum()
    if not np.allclose(ratio_total, 1.0):
        ratios = ratios / ratio_total
        new_ratio_summary_lines = []
        for i in range(len(ratios)):
            label_str: str
            if labels is not None:
                label_str = f"'{labels[i]}'"
            else:
                label_str = f"{i}"
            new_ratio_summary_lines.append(
                f" ❯ Source {label_str}: target ratio adjusted from {target_ratios[i]} to {ratios[i]}"
            )
        new_ratio_summary = "\n".join(new_ratio_summary_lines)
        warnings.warn(
            f"Target mixing ratios don't sum to 1. They will be normalized as follows:\n{new_ratio_summary}",
            UserWarning,
        )

    # Determine the number of items to sample from each source.
    # This is tricky because the sources may have different sizes, yet we want to stay
    # true to the sampling ratios while minimizing the number of dropped or over-sampled items.
    # To that end, the optimal natural distribution of items over sources is the one that
    # matches the target sampling ratios. We'll call that the 'ideal_sample_sizes'.
    ideal_sample_sizes = target_size * ratios
    # But since the actual (natural) distribution probably differs from the ideal one, it's
    # not possible to match the target ratios without some dropping or oversampling.
    # So we first calculate how much oversampling/repetition is needed per source, and then cap that
    # according to the given `max_repetition_factors`.
    max_repetition_factors_needed = np.maximum(ideal_sample_sizes / sizes, 1.0)
    repetition_factors_to_use = np.minimum(max_repetition_factors_, max_repetition_factors_needed)
    # Now we can adjust sizes based on the repetitions needed.
    sizes_to_use = sizes * repetition_factors_to_use
    # Lastly, we need to adjust the ideal sample sizes down by the smallest common factor
    # that would result in all sample sizes being less than or equal to the number of items available
    # from the corresponding source. We can calculate that factor by finding the source with the
    # largest relative difference between its available size (number of items after oversampling) and
    # its ideal sample size, and taking that ratio.
    adjustment_factor = min(1.0, (sizes_to_use / ideal_sample_sizes).min())
    actual_sample_sizes = ideal_sample_sizes * adjustment_factor

    # Sanity check.
    # Sample sizes should stay true to target ratios.
    actual_ratios = actual_sample_sizes / actual_sample_sizes.sum()
    assert np.allclose(
        ratios, actual_ratios
    ), f"expected ratios: {ratios}, actual ratios: {actual_ratios}"
    # And sample sizes shouldn't be larger than the number of items available.
    # NOTE: the cast to an unsigned integer type truncates any fractional part.
    actual_sample_sizes_int = actual_sample_sizes.astype(np.uint64)
    assert (actual_sample_sizes_int <= sizes_to_use).all()
    assert target_size is not None
    actual_size = actual_sample_sizes.sum()
    if strict and not np.allclose(target_size, actual_size):
        # Report the source whose needed repetition factor exceeds its allowed factor the most.
        idx_of_max_diff = np.argmax(max_repetition_factors_needed - max_repetition_factors_)
        if labels is not None:
            label_str = f"with label '{labels[idx_of_max_diff]}'"
        else:
            label_str = f"with index {idx_of_max_diff}"
        required_sample_size = int(ideal_sample_sizes[idx_of_max_diff])
        provided_sample_size = int(sizes[idx_of_max_diff] * max_repetition_factors[idx_of_max_diff])
        raise OLMoConfigurationError(
            f"Unable to meet target size of {int(target_size):,d} {unit} with the given "
            f"source ratios and max repetition factors. The best we can do is {int(actual_size):,d} {unit}. "
            f"The source with the biggest discrepancy between its required sample size "
            f"(~{required_sample_size:,d} {unit}, {100 * ratios[idx_of_max_diff]:.1f}% of mix) and "
            f"the size it can provide after accounting for the max repetition factor "
            f"({sizes[idx_of_max_diff]:,d} x {max_repetition_factors[idx_of_max_diff]:.2f} ~= {provided_sample_size:,d} {unit}) "
            f"is the source {label_str}. Consider either decreasing the target size of the mix "
            f"to {int(actual_size):,d} {unit} or increasing the max repetition factor for that source "
            f"to {max_repetition_factors_needed[idx_of_max_diff]:.2f}."
        )

    return actual_sample_sizes_int


def build_global_indices(
    total_instances: int,
    *,
    sequence_length: int,
    max_sequence_length: int,
    seed: Optional[int],
    dtype: NumpyUIntTypes = np.uint32,
) -> np.ndarray:
    """
    Build global (as opposed to rank-local) instance indices as a numpy array, in a way that
    preserves the order of data when ``max_sequence_length`` is fixed but ``sequence_length`` changes.

    Instances are grouped into chunks of ``max_sequence_length // sequence_length`` consecutive
    indices. Only the order of the chunks is shuffled; indices within a chunk stay consecutive.

    :param total_instances: The total number of instances. Must be smaller than the maximum value
        of ``dtype`` and a multiple of the chunk size.
    :param sequence_length: The current sequence length. Must divide ``max_sequence_length``.
    :param max_sequence_length: The maximum sequence length.
    :param seed: The seed used to shuffle the chunk order. If ``None``, no shuffling is done.
    :param dtype: The unsigned integer dtype of the returned array.

    :returns: A 1-D array of ``total_instances`` indices with dtype ``dtype``.

    :raises AssertionError: If any of the constraints above is violated.
    """
    assert total_instances < np.iinfo(dtype).max
    assert max_sequence_length % sequence_length == 0
    chunk_size = max_sequence_length // sequence_length
    # Length of dataset would be calculated incorrectly if this didn't hold.
    assert total_instances % chunk_size == 0

    # NOTE: To guarantee the same data order with `max_sequence_length` fixed but `sequence_length`
    # changing, we need `total_instances // chunk_size` to remain constant.
    # This is ensured by requiring that `max_sequence_length` is a multiple of `sequence_length`
    # and assuming that `total_instances` is proportional to `chunk_size`, i.e.
    # if `sequence_length` is half of `max_sequence_length`, then `total_instances`
    # should double. Callers need to take care that excess tokens are dropped in a way that
    # respects `max_sequence_length`, not `sequence_length`.
    chunk_indices = np.arange(total_instances // chunk_size, dtype=dtype)

    # Shuffle the chunk order in place with a generator obtained from the seed.
    if seed is not None:
        rng = get_rng(seed)
        rng.shuffle(chunk_indices)

    if chunk_size == 1:
        return chunk_indices

    # Build the within-chunk offsets in `dtype` as well. With NumPy's default integer type the
    # sum below would be promoted away from `dtype` (e.g. uint64 + int64 -> float64).
    # No overflow is possible: the largest value produced is `total_instances - 1`.
    offsets = np.arange(chunk_size, dtype=dtype)
    indices = (chunk_indices * chunk_size).reshape((-1, 1)) + offsets.reshape((1, -1))
    return indices.reshape(-1)


class _NOT_SET_INT_TYPE(int):
    # Marker subclass of ``int``. Its module-level instance ``SEED_NOT_SET`` is recognized by
    # identity (``is``) in ``resolve_seed()``, not by its integer value.
    # NOTE(review): presumably subclassing ``int`` lets the sentinel serve as the default of
    # ``int``-typed seed parameters; confirm against callers.
    pass


# Sentinel instance. Constructed without arguments, so its integer value is 0, but
# ``resolve_seed()`` detects it by identity rather than by value.
SEED_NOT_SET = _NOT_SET_INT_TYPE()
"""
A placeholder for the default seed, which can be changed by calling :func:`set_composable_seed()`.
"""

# Module-wide generator that ``resolve_seed()`` draws seeds from. ``None`` means no composable
# seed is currently set.
_SEED_RNG: Optional[np.random.Generator] = None


# Type variable for the seed values that ``resolve_seed()`` accepts and returns.
S = TypeVar("S", int, None, Optional[int])


def resolve_seed(default: S) -> S:
    """
    Turn a possibly-unset seed into a concrete value.

    Any value other than :data:`SEED_NOT_SET` is returned unchanged. For :data:`SEED_NOT_SET`
    the result is ``0`` when no global generator is set, and otherwise an integer drawn from the
    global generator (which advances that generator).

    :param default: The seed to resolve.

    :returns: The resolved seed.
    """
    if default is not SEED_NOT_SET:
        return default
    if _SEED_RNG is None:
        return 0  # type: ignore[return-type]
    return int(_SEED_RNG.integers(0, 2**31 - 1))  # type: ignore[return-type]


def set_composable_seed(seed: int):
    """
    Set the global seed for the composable module.

    Replaces the module-level random generator with a new one obtained from ``seed``.

    :param seed: The seed to create the generator from.
    """
    global _SEED_RNG
    _SEED_RNG = get_rng(seed)
def reset_composable_seed():
    """
    Reset the global seed for the composable module.

    Clears the module-level random generator by setting it back to ``None``.
    """
    global _SEED_RNG
    _SEED_RNG = None