import functools as ft
import gzip
import math
import os
import random
from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import islice
from pathlib import Path
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
import numpy as np
import torch
import torch.nn.functional as F
from olmo_core.aliases import PathOrStr
from olmo_core.io import (
add_cached_path_clients,
get_bytes_range,
get_file_size,
is_url,
resource_path,
)
from olmo_core.utils import capped_powers_of_2
from .types import LongDocStrategy
[docs]
def split_batch(batch: Dict[str, Any], num_microbatch_instances: int) -> List[Dict[str, Any]]:
    """
    Break a batch (e.g. one produced by the :class:`DataCollator`) into consecutive micro-batches.

    :param batch: The batch to split. It must contain an ``"input_ids"`` tensor whose first
        dimension is the batch size. Every value must be a tensor or a list.
    :param num_microbatch_instances: The maximum number of instances per micro-batch.

    :returns: A list of micro-batches. When the batch already has at most
        ``num_microbatch_instances`` instances, the list holds the original batch object.

    :raises RuntimeError: If ``num_microbatch_instances`` is not positive, or the batch contains
        a value that is neither a tensor nor a list.
    """
    if num_microbatch_instances <= 0:
        raise RuntimeError("microbatch size is too small!")

    total_instances = batch["input_ids"].shape[0]
    if total_instances <= num_microbatch_instances:
        # Nothing to split: hand back the batch itself, not a copy.
        return [batch]

    num_chunks = math.ceil(total_instances / num_microbatch_instances)
    pieces_by_key = {}
    for name, field in batch.items():
        if isinstance(field, torch.Tensor):
            pieces_by_key[name] = field.split(num_microbatch_instances, dim=0)
        elif isinstance(field, list):
            chunk_starts = range(
                0, num_chunks * num_microbatch_instances, num_microbatch_instances
            )
            pieces_by_key[name] = [
                field[begin : begin + num_microbatch_instances] for begin in chunk_starts
            ]
        else:
            raise RuntimeError(f"unexpected item in batch: '{name}={field}'")

    # Re-assemble: the i-th micro-batch takes the i-th piece of every field.
    micro_batches = []
    for chunk_idx in range(len(pieces_by_key["input_ids"])):
        micro_batches.append({name: pieces[chunk_idx] for name, pieces in pieces_by_key.items()})
    return micro_batches
[docs]
def melt_batch(batch: Dict[str, Any], target_sequence_length: int) -> Dict[str, Any]:
    """
    "Melts" a batch by shortening the sequence length and proportionally increasing the number
    of instances.

    Each instance is cut into ``current_sequence_length // target_sequence_length`` consecutive,
    non-overlapping pieces of ``target_sequence_length`` items, and the pieces of one instance
    stay adjacent in the result.

    :param batch: The batch to melt. It must contain a 2D ``"input_ids"`` tensor of shape
        ``(batch_size, sequence_length)``.
    :param target_sequence_length: The new sequence length. The current sequence length must be
        a multiple of it.

    :returns: The original batch object if its sequence length is already at most
        ``target_sequence_length``, otherwise a new dictionary with the melted values.

    :raises RuntimeError: If the current sequence length is not a multiple of the target, or if
        the batch contains a value that cannot be melted (a tensor with an unsupported shape,
        a nested list with the wrong length, an empty list, or any other type).
    """
    current_batch_size, current_sequence_length = batch["input_ids"].shape
    if current_sequence_length <= target_sequence_length:
        return batch

    if current_sequence_length % target_sequence_length != 0:
        raise RuntimeError(
            "current sequence of batch must be a multiple of the target sequence length "
            "in order to 'melt' the batch"
        )

    ratio = current_sequence_length // target_sequence_length

    new_batch: Dict[str, Any] = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            if value.shape == (current_batch_size, current_sequence_length):
                # Per-token tensor: each row becomes `ratio` consecutive rows.
                new_batch[key] = value.reshape(-1, target_sequence_length)
            elif value.shape == (current_batch_size,) or value.shape == (current_batch_size, 1):
                # Per-instance tensor: repeat each entry once per piece. Without a `dim`
                # argument the result is always 1D, including for (batch_size, 1) inputs.
                new_batch[key] = value.repeat_interleave(ratio)
            else:
                raise RuntimeError(
                    f"unable to melt '{key}' tensor in batch with shape '{value.shape}'"
                )
        elif isinstance(value, list) and len(value) > 0:
            melted: List[Any] = []
            for item in value:
                if isinstance(item, list):
                    if len(item) != current_sequence_length:
                        raise RuntimeError(f"unexpected item length for '{key}' in batch")
                    # Pieces must start at multiples of the *target length* (not of `ratio`)
                    # so that they tile the item exactly, mirroring the tensor reshape above.
                    melted.extend(
                        item[start : start + target_sequence_length]
                        for start in range(0, current_sequence_length, target_sequence_length)
                    )
                else:
                    # Per-instance metadata is duplicated for every piece of the instance.
                    melted.extend([item] * ratio)
            new_batch[key] = melted
        else:
            raise RuntimeError(f"unexpected item in batch: '{key}={value}'")

    return new_batch
[docs]
def truncate_batch(batch: Dict[str, Any], target_sequence_length: int) -> Dict[str, Any]:
    """
    Cut every instance in a batch down to ``target_sequence_length``.

    :param batch: The batch to truncate. It must contain a 2D ``"input_ids"`` tensor of shape
        ``(batch_size, sequence_length)``.
    :param target_sequence_length: The maximum sequence length to keep.

    :returns: The original batch object if it is already short enough, otherwise a new
        dictionary with the truncated values.

    :raises RuntimeError: If the batch contains a value that cannot be truncated (a tensor with
        an unsupported shape, a nested list with the wrong length, an empty list, or any
        other type).
    """
    batch_size, seq_len = batch["input_ids"].shape
    if seq_len <= target_sequence_length:
        return batch

    per_token_shape = (batch_size, seq_len)
    per_instance_shapes = ((batch_size,), (batch_size, 1))

    truncated: Dict[str, Any] = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            if value.shape == per_token_shape:
                truncated[key] = value[:, :target_sequence_length]
            elif value.shape == per_instance_shapes[0] or value.shape == per_instance_shapes[1]:
                # Per-instance tensors have no sequence dimension, so they pass through as-is.
                truncated[key] = value
            else:
                raise RuntimeError(
                    f"unable to truncate '{key}' tensor in batch with shape '{value.shape}'"
                )
            continue

        if not (isinstance(value, list) and len(value) > 0):
            raise RuntimeError(f"unexpected item in batch: '{key}={value}'")

        clipped = []
        for entry in value:
            if not isinstance(entry, list):
                # Per-instance metadata is kept unchanged.
                clipped.append(entry)
                continue
            if len(entry) != seq_len:
                raise RuntimeError(f"unexpected item length for '{key}' in batch")
            clipped.append(entry[:target_sequence_length])
        truncated[key] = clipped

    return truncated
[docs]
def write_document_indices(data_path: Path, *, dtype, eos_token_id: int) -> Path:
    """
    Given a local ".npy" data path from the Dolma toolkit, write a gzipped CSV metadata file
    holding the start/end indices of each document within the array.

    A document ends right after each EOS token, and the next one starts there. Tokens after
    the last EOS token are not recorded.

    :param data_path: Path to the local token ID file.
    :param dtype: The numpy datatype of the token ID file.
    :param eos_token_id: The EOS token ID that marks the end of a document.

    :returns: The path of the metadata file, i.e. ``data_path`` with its suffix replaced by
        ``.csv.gz``.
    """
    tokens = np.memmap(data_path, mode="r", dtype=dtype)
    # End indices are exclusive: one past each EOS position.
    doc_ends = np.flatnonzero(tokens == eos_token_id) + 1

    metadata_path = data_path.with_suffix(".csv.gz")
    with gzip.open(metadata_path, mode="wt") as out:
        doc_start = 0
        for doc_end in doc_ends:
            out.write(f"{doc_start},{doc_end}\n")
            doc_start = doc_end

    return metadata_path
[docs]
def iter_document_indices(
    data_path: PathOrStr,
    *,
    local_cache: Optional[PathOrStr] = None,
    use_array_if_local: Optional[bool] = None,
    eos_token_id: Optional[int] = None,
    bos_token_id: Optional[int] = None,
    dtype=None,
) -> Generator[Tuple[int, int], None, None]:
    """
    Given a ".npy" data path from the Dolma toolkit, iterate over the document start/end
    indices within the array.

    :param data_path: Path to a ".npy" Dolma toolkit data file.
    :param local_cache: Local directory to put downloads into.
    :param use_array_if_local: Use the numpy data array to find the document indices if the array
        is on the local filesystem and ``eos_token_id`` and ``dtype`` are provided.
        This can be a lot faster. Otherwise relies on the metadata file.
        When left as ``None`` the array is used whenever that is possible.
    :param eos_token_id: The EOS token ID.
        Required to use the local data array instead of the metadata file.
    :param bos_token_id: The BOS token ID. When given, a document boundary in the local data
        array is an EOS token immediately followed by a BOS token (plus a trailing EOS token).
    :param dtype: The data type of the numpy data array.
        Required to use the local data array instead of the metadata file. When reading from
        the metadata file it is used to validate the indices against the size of the source.

    :raises ValueError: If the local array was requested without ``eos_token_id``/``dtype``.
    :raises RuntimeError: If the metadata file can't be found, or it contains indices that
        are out-of-bounds for the source file.
    """
    if use_array_if_local is None and eos_token_id is not None and dtype is not None:
        if not is_url(data_path):
            use_array_if_local = True

    if use_array_if_local and not is_url(data_path):
        # Fast path: scan the local token array directly for document boundaries.
        if eos_token_id is None or dtype is None:
            raise ValueError(
                "'eos_token_id' and 'dtype' are required to use the local array for finding document indices"
            )

        token_ids = np.memmap(data_path, mode="r", dtype=dtype)
        if bos_token_id is None:
            boundaries = (token_ids == eos_token_id).nonzero()[0]
        else:
            eos_then_bos = np.logical_and(
                token_ids[:-1] == eos_token_id, token_ids[1:] == bos_token_id
            )
            boundaries = eos_then_bos.nonzero()[0]
            # The pairwise check can't see a trailing EOS token, so add it explicitly.
            if token_ids[-1] == eos_token_id:
                boundaries = np.append(boundaries, token_ids.shape[0] - 1)

        doc_start = 0
        for boundary in boundaries:
            doc_end = boundary + 1
            yield doc_start, doc_end
            doc_start = doc_end
        return

    # Slow path: read the indices from the ".csv.gz" metadata file next to the source file.
    metadata_filename = os.path.basename(data_path).replace(".npy", ".csv.gz")
    try:
        metadata_path = resource_path(
            os.path.dirname(data_path),
            metadata_filename,
            local_cache=local_cache,
        )
    except FileNotFoundError as e:
        raise RuntimeError(
            f"Source metadata file '{metadata_filename}' is required to calculate document indices for '{data_path}'. "
            "If the source data file is local (on-disk) and 'eos_token_id' and 'dtype' are provided, then the document "
            "indices can be inferred from the source file."
        ) from e

    # Bounds can only be validated when the item size (and so the token count) is known.
    total_tokens: Optional[int] = None
    if dtype is not None:
        total_tokens = get_file_size(data_path) // dtype(0).itemsize

    with gzip.open(metadata_path, "rt") as metadata_file:
        for row in metadata_file:
            start_field, end_field, *_ = row.split(",")
            start_index = int(start_field)
            end_index = int(end_field)
            if total_tokens is not None and start_index >= total_tokens:
                raise RuntimeError(
                    f"Document start index {start_index:,d} from metadata file "
                    f"for source '{data_path}' with {total_tokens:,d} tokens is out-of-bounds"
                )
            if total_tokens is not None and end_index > total_tokens:
                raise RuntimeError(
                    f"Document end index {end_index:,d} from metadata file "
                    f"for source '{data_path}' with {total_tokens:,d} tokens is out-of-bounds"
                )
            yield start_index, end_index
[docs]
def iter_document_indices_with_max_sequence_length(
    data_path: PathOrStr,
    max_sequence_length: int,
    *,
    local_cache: Optional[PathOrStr] = None,
    use_array_if_local: Optional[bool] = None,
    eos_token_id: Optional[int] = None,
    bos_token_id: Optional[int] = None,
    dtype=None,
    long_doc_strategy: LongDocStrategy = LongDocStrategy.truncate,
) -> Generator[Tuple[int, int], None, None]:
    """
    Like :func:`iter_document_indices` but documents longer than ``max_sequence_length`` are
    either truncated or split, depending on ``long_doc_strategy``.

    :param data_path: Path to a ".npy" Dolma toolkit data file.
    :param max_sequence_length: The maximum length of any yielded document.
    :param long_doc_strategy: With ``truncate`` only the first ``max_sequence_length`` tokens of
        a long document are kept. With ``fragment`` a long document is split into consecutive
        pieces of at most ``max_sequence_length`` tokens.

    The remaining parameters are forwarded to :func:`iter_document_indices`.

    :raises NotImplementedError: For any other ``long_doc_strategy``.
    """
    documents = iter_document_indices(
        data_path,
        local_cache=local_cache,
        use_array_if_local=use_array_if_local,
        eos_token_id=eos_token_id,
        bos_token_id=bos_token_id,
        dtype=dtype,
    )
    for doc_start, doc_end in documents:
        if doc_end - doc_start <= max_sequence_length:
            # Short enough already: pass through unchanged.
            yield doc_start, doc_end
            continue

        if long_doc_strategy == LongDocStrategy.truncate:
            yield doc_start, doc_start + max_sequence_length
        elif long_doc_strategy == LongDocStrategy.fragment:
            for piece_start in range(doc_start, doc_end, max_sequence_length):
                yield piece_start, min(doc_end, piece_start + max_sequence_length)
        else:
            raise NotImplementedError(long_doc_strategy)
[docs]
def get_document_indices(
    data_path: PathOrStr, local_cache: Optional[PathOrStr] = None
) -> List[Tuple[int, int]]:
    """
    Like :func:`iter_document_indices` but collects the indices into a list.

    :param data_path: Path to a ".npy" Dolma toolkit data file.
    :param local_cache: Local directory to put downloads into.

    :returns: A list of ``(start, end)`` document indices.
    """
    indices = iter_document_indices(data_path, local_cache=local_cache)
    return list(indices)
[docs]
def load_array_slice(
    path: PathOrStr,
    start_idx: int,
    end_idx: int,
    dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64], Type[np.bool_]],
) -> np.ndarray:
    """
    Read a contiguous slice of a numpy array stored on disk.

    :param path: The path/URL to the array.
    :param start_idx: The start index (0-based) of the slice within the array.
    :param end_idx: The end index (0-based, exclusive) of the slice within the array.
    :param dtype: The numpy datatype of the array.

    :returns: The requested items as a 1D array built directly on the fetched bytes.
    """
    bytes_per_item = dtype(0).itemsize
    # Translate the item range into a byte offset and a byte count.
    raw = get_bytes_range(
        path, start_idx * bytes_per_item, (end_idx - start_idx) * bytes_per_item
    )
    return np.frombuffer(raw, dtype=dtype)
[docs]
def load_array_slice_into_tensor(
    path: PathOrStr,
    start_idx: int,
    end_idx: int,
    dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64], Type[np.bool_]],
) -> torch.Tensor:
    """
    Read a chunk from a numpy array on disk and return it as a :class:`torch.Tensor`.

    :param path: The path/URL to the array.
    :param start_idx: The start index (0-based) of the chunk within the array.
    :param end_idx: The end index (0-based, exclusive) of the chunk within the array.
    :param dtype: The numpy datatype of the array.

    :returns: A boolean tensor when ``dtype`` is ``np.bool_``, otherwise a ``torch.long`` tensor.
    """
    chunk = load_array_slice(path, start_idx, end_idx, dtype)
    if dtype != np.bool_:
        # Integer data is widened to the platform integer first and stored as int64.
        return torch.tensor(chunk.astype(np.int_), dtype=torch.long)
    return torch.tensor(chunk)
[docs]
def get_document_lengths(
    input_ids: Union[torch.Tensor, np.ndarray],
    eos_token_id: int,
    bos_token_id: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute the length of each document in a 1D sequence of token IDs.

    :param input_ids: An integer-type tensor (or numpy array) of token IDs.
    :param eos_token_id: The ID of the EOS token (use to denote document boundaries).
    :param bos_token_id: The ID of the BOS token (use to denote document boundaries). When provided,
        every document must start with a BOS token, and a boundary is an EOS token immediately
        followed by a BOS token.

    :returns: A 1D ``torch.int32`` tensor of document lengths. Trailing tokens that are not
        terminated by a boundary count as a final document.
    """
    if not isinstance(input_ids, torch.Tensor):
        input_ids = torch.tensor(input_ids.astype(np.int_), dtype=torch.long)

    last_position = input_ids.shape[0] - 1
    if bos_token_id is None:
        is_boundary = input_ids == eos_token_id
        # Only close the sequence explicitly when it doesn't already end on an EOS token.
        closing = [] if input_ids[-1] == eos_token_id else [last_position]
    else:
        is_boundary = torch.logical_and(
            input_ids[:-1] == eos_token_id, input_ids[1:] == bos_token_id
        )
        # The pairwise check never covers the last position, so it always closes the sequence.
        closing = [last_position]

    # Position of the last token of each document, preceded by a virtual boundary at -1.
    boundaries = torch.cat(
        [
            torch.tensor([-1], dtype=torch.int32),
            is_boundary.nonzero(as_tuple=True)[0].to(dtype=torch.int32),
            torch.tensor(closing, dtype=torch.int32),
        ]
    )
    return boundaries[1:] - boundaries[:-1]
[docs]
def get_cumulative_document_lengths(doc_lens: torch.Tensor) -> torch.Tensor:
    """
    Turn a batched tensor of document lengths into a 1D tensor of cumulative document
    lengths for the whole batch.

    Zero entries in ``doc_lens`` are dropped before accumulating.

    :param doc_lens: The document lengths, such as those returned by :func:`get_document_lengths`.

    :returns: A 1D ``torch.int32`` tensor that starts at 0 and is followed by the running sum
        of the non-zero document lengths.
    """
    nonzero_lens = doc_lens[doc_lens != 0]
    running_total = torch.cumsum(nonzero_lens, 0, dtype=torch.int32)
    leading_zero = torch.tensor([0], dtype=torch.int32, device=doc_lens.device)
    return torch.cat([leading_zero, running_total])
def iter_batched(
    iterable: Iterable[Dict[str, Any]], batch_num_tokens: int
) -> Iterable[Tuple[Dict[str, Any], ...]]:
    """
    Group items into batches that hold at most ``batch_num_tokens`` tokens each.

    :param iterable: The items to batch. Each must have an ``"input_ids"`` tensor, and all items
        within one batch must share the same ``"input_ids"`` shape.
    :param batch_num_tokens: The maximum number of tokens per batch.

    :returns: An iterator over batches, each a tuple of items.

    :raises RuntimeError: If items within a batch have different ``"input_ids"`` shapes.
    """
    pending: List[Dict[str, Any]] = []
    pending_tokens = 0
    expected_shape: Optional[Tuple[int, ...]] = None

    for item in iterable:
        item_tokens = item["input_ids"].numel()
        assert item_tokens <= batch_num_tokens, f"{item_tokens} > {batch_num_tokens}"

        if pending_tokens + item_tokens > batch_num_tokens:
            # This item would overflow the batch: emit what we have and start over.
            yield tuple(pending)
            pending = []
            pending_tokens = 0
            expected_shape = None

        pending.append(item)
        pending_tokens += item_tokens

        item_shape = tuple(item["input_ids"].shape)
        if expected_shape is not None and expected_shape != item_shape:
            raise RuntimeError(
                f"Items in batch don't have the same shape! Expected {expected_shape}, "
                f"got {item_shape}"
            )
        expected_shape = item_shape

    if pending:
        yield tuple(pending)
[docs]
@contextmanager
def memmap_to_write(
    path: Path,
    *,
    shape: Tuple[int, ...],
    dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64], Type[np.bool_]],
) -> Generator[np.ndarray, None, None]:
    """
    A context manager for safely writing a numpy memory-mapped array to disk.

    The memory-mapped ndarray returned by the context manager will be mapped to a temporary
    file until the context exits successfully, at which point the temporary file is moved
    to ``path``. If the body of the ``with`` block raises, the temporary file is removed,
    ``path`` is left untouched, and the exception propagates to the caller.

    :param path: The final destination of the array. Parent directories are created as needed.
    :param shape: The shape of the array to create.
    :param dtype: The numpy datatype of the array to create.

    :raises FileNotFoundError: If the temporary file disappeared before it could be moved into
        place and ``path`` doesn't exist either.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    # NOTE: we use 'random.SystemRandom' here to minimize the probability of collisions in temp
    # filenames from different runs using the same seed and working directory.
    tmp_path = path.with_suffix(f".{random.SystemRandom().randint(0, 2**32)}.npy.tmp")
    mmap = np.memmap(tmp_path, dtype=dtype, mode="w+", shape=shape)
    try:
        yield mmap
    except BaseException:
        # The caller failed part-way through. Discard the partial temp file and re-raise:
        # swallowing the error here would make the context manager suppress it and then
        # try to move a file that no longer exists.
        tmp_path.unlink(missing_ok=True)
        raise

    mmap.flush()
    del mmap

    try:
        tmp_path.replace(path)
    except FileNotFoundError:
        # Handle potential race condition if multiple processes are trying to replace the same
        # 'tmp_path' concurrently, in which case we might get a FileNotFoundError because the
        # 'tmp_path' was already moved by another process.
        # In this case we'll ignore the error if 'path' already exists.
        if not path.is_file():
            raise
[docs]
def write_array_to_disk(arr: np.ndarray, path: Path):
    """
    Save a numpy array to disk as raw data, in the same simple format that ``np.memmap`` uses.

    :param arr: The array to write.
    :param path: The destination file.
    """
    # Writing goes through `memmap_to_write`, which handles the temporary file.
    with memmap_to_write(path, shape=arr.shape, dtype=arr.dtype) as target:
        target[:] = arr
def divide_into_buckets(n: int, b: int) -> List[int]:
    """
    Split ``n`` items across ``b`` buckets as evenly as possible.

    Larger buckets come first: each bucket takes the ceiling of the remaining items divided by
    the remaining number of buckets.

    :param n: The number of items to distribute.
    :param b: The number of buckets.

    :returns: The size of each bucket, ``b`` entries in total.
    """
    sizes: List[int] = []
    for buckets_left in range(b, 0, -1):
        share = math.ceil(n / buckets_left)
        sizes.append(share)
        n -= share
    return sizes
def chunk_array(arr: np.ndarray, chunk_sizes: Sequence[int]) -> List[np.ndarray]:
    """
    Cut a 1D array into consecutive chunks of the given sizes.

    :param arr: The 1D array to split.
    :param chunk_sizes: The size of each chunk, in order. They must add up to the array length.

    :returns: One slice of ``arr`` per entry in ``chunk_sizes``.
    """
    assert len(arr.shape) == 1
    assert sum(chunk_sizes) == arr.shape[0]

    # Running boundaries: chunk i spans [bounds[i], bounds[i + 1]).
    bounds = [0]
    for size in chunk_sizes:
        bounds.append(bounds[-1] + size)

    return [arr[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]
def get_rng(seed: int) -> np.random.Generator:
    """
    Build a numpy random generator backed by a PCG64 bit generator seeded with ``seed``.

    :param seed: The seed for the bit generator.
    """
    bit_generator = np.random.PCG64(seed=seed)
    return np.random.Generator(bit_generator)
[docs]
def bucket_documents(
    path: PathOrStr,
    target: Path,
    *,
    buckets: Sequence[int],
    eos_token_id: int,
    dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]],
    indices_dtype: Union[
        Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]
    ] = np.uint32,
) -> Tuple[int, int]:
    """
    Bucket documents by sequence lengths in powers of 2, saving the indices of the bucketed
    documents to ``target``.

    Each source document is decomposed with ``capped_powers_of_2`` into pieces no longer than
    the largest bucket. Pieces are consumed in order until one is smaller than the smallest
    bucket, and the rest of that document is dropped.

    :param path: Path/URL to the source file of token IDs.
    :param target: Where to write the flat array of ``start, end`` indices.
    :param buckets: The bucket sizes; only the smallest and largest are used here.
    :param eos_token_id: The EOS token ID, used to find document boundaries.
    :param dtype: The numpy datatype of the source file.
    :param indices_dtype: The numpy datatype to use for the written indices.

    :returns: The number of original documents and the number of new bucketed documents.

    :raises RuntimeError: If no bucketed documents could be produced.
    """
    largest_bucket = max(buckets)
    smallest_bucket = min(buckets)

    num_source_docs = 0
    flat_bounds = []  # start_0, end_0, start_1, end_1, ...
    for doc_start, doc_end in iter_document_indices(
        path, eos_token_id=eos_token_id, dtype=dtype
    ):
        num_source_docs += 1
        cursor = doc_start
        for piece_len in capped_powers_of_2(doc_end - doc_start, largest_bucket):
            if piece_len < smallest_bucket:
                break
            flat_bounds.extend((cursor, cursor + piece_len))
            cursor += piece_len

    if not flat_bounds:
        raise RuntimeError(f"Failed to produce any bucketed documents for source file at '{path}'")

    with memmap_to_write(target, dtype=indices_dtype, shape=(len(flat_bounds),)) as indices_mmap:
        indices_mmap[:] = flat_bounds

    return num_source_docs, len(flat_bounds) // 2
[docs]
def segment_documents_into_instances(
    path: PathOrStr,
    target: Path,
    *,
    max_sequence_length: int,
    eos_token_id: int,
    dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]],
    indices_dtype: Union[
        Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]
    ] = np.uint32,
    bos_token_id: Optional[int] = None,
    sample: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """
    Segment documents into instances of at most ``max_sequence_length`` tokens, saving the
    indices of the instances to ``target``.

    Each document yields exactly one instance; documents longer than ``max_sequence_length``
    are truncated.

    :param path: Path/URL to the source file of token IDs.
    :param target: Where to write the flat array of ``start, end`` indices.
    :param max_sequence_length: The maximum number of tokens per instance.
    :param eos_token_id: The EOS token ID, used to find document boundaries.
    :param dtype: The numpy datatype of the source file.
    :param indices_dtype: The numpy datatype to use for the written indices.
    :param bos_token_id: The BOS token ID, used together with the EOS token ID to find
        document boundaries.
    :param sample: If given as ``(max_instances, seed)``, draw ``max_instances`` instances with
        a generator seeded by ``seed`` (using ``Generator.choice`` with its default settings).

    :returns: The number of original documents and the number of resulting instances.

    :raises RuntimeError: If no instances were produced.
    """

    def flat_instance_bounds() -> Generator[int, None, None]:
        # Emits start_0, end_0, start_1, end_1, ... with each end clipped to the max length.
        for doc_start, doc_end in iter_document_indices(
            path, eos_token_id=eos_token_id, bos_token_id=bos_token_id, dtype=dtype
        ):
            yield doc_start
            yield doc_start + min(doc_end - doc_start, max_sequence_length)

    indices = np.fromiter(flat_instance_bounds(), dtype=indices_dtype)
    total_og_docs = len(indices) // 2

    if sample is not None:
        max_instances, seed = sample
        # Sample whole (start, end) rows, then flatten back to the on-disk layout.
        instance_rows = indices.reshape(-1, 2)
        indices = get_rng(seed).choice(instance_rows, size=max_instances).reshape(-1)

    if indices.size == 0:
        raise RuntimeError(f"Failed to produce any documents from '{path}'")

    with memmap_to_write(target, dtype=indices_dtype, shape=(indices.size,)) as indices_mmap:
        indices_mmap[:] = indices

    return total_og_docs, len(indices) // 2
def run_worker_func(func, *args, **kwargs):
    """
    Call ``add_cached_path_clients()`` and then invoke ``func`` with the given arguments.

    :returns: Whatever ``func(*args, **kwargs)`` returns.
    """
    # NOTE(review): presumably this sets up remote-storage clients in a worker process
    # before the actual work runs -- confirm against `olmo_core.io`.
    add_cached_path_clients()
    result = func(*args, **kwargs)
    return result
def get_doc_lengths_from_indices(doc_indices: np.ndarray) -> np.ndarray:
    """
    Compute document lengths from a flat array of interleaved ``start, end`` indices.

    :param doc_indices: A 1D array laid out as ``start_0, end_0, start_1, end_1, ...``.

    :returns: ``end_i - start_i`` for every document.
    """
    starts = doc_indices[0::2]
    ends = doc_indices[1::2]
    return ends - starts
def get_labels(batch: Dict[str, Any], label_ignore_index: int = -100) -> torch.Tensor:
    """
    Build next-token labels from a batch.

    The labels are a copy of ``batch["input_ids"]`` shifted one position to the left, with the
    last position of every sequence filled with ``label_ignore_index``. Positions excluded by
    the optional ``"label_mask"``, ``"attention_mask"`` or ``"instance_mask"`` entries are set
    to ``label_ignore_index`` before shifting.

    :param batch: The batch; must contain ``"input_ids"``.
    :param label_ignore_index: The value that marks positions to ignore.

    :returns: A new tensor with the same shape as ``batch["input_ids"]``.
    """
    labels = batch["input_ids"].clone()

    label_mask = batch.get("label_mask")
    if label_mask is not None:
        labels.masked_fill_(~label_mask, label_ignore_index)

    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        labels.masked_fill_(attention_mask == 0.0, label_ignore_index)

    instance_mask = batch.get("instance_mask")
    if instance_mask is not None:
        # One flag per instance, broadcast across the sequence dimension.
        labels.masked_fill_(~instance_mask.unsqueeze(-1), value=label_ignore_index)

    # Drop the first token and pad on the right so the shape is unchanged.
    shifted = labels[..., 1:]
    return F.pad(shifted, (0, 1, 0, 0), value=label_ignore_index)
[docs]
def find_end_first_consecutive_true(arr: np.ndarray) -> int:
    """Return the (exclusive) end position of the run of True values that starts at index 0.

    Returns 0 when the array doesn't start with True, and ``len(arr)`` when every
    element is True.
    """
    if not arr[0]:
        return 0

    total = len(arr)
    running_count = np.cumsum(arr)
    if running_count[-1] == total:
        return total

    # The running count stalls exactly where a False value follows; the first stall
    # marks the element right after the leading run.
    stalls = np.flatnonzero(running_count[:-1] == running_count[1:])
    return stalls[0] + 1
[docs]
def find_start_last_consecutive_true(arr: np.ndarray) -> int:
    """Return the start position of the run of True values that ends at the last index.

    Returns -1 when the array doesn't end with True.
    """
    # Length of the trailing run, measured as the leading run of the reversed array.
    trailing_run = find_end_first_consecutive_true(arr[::-1])
    if trailing_run > 0:
        return len(arr) - trailing_run
    return -1
[docs]
def group_consecutive_values(arr: np.ndarray, stepsize: int = 1) -> List[np.ndarray]:
    """Split an array into runs in which neighbouring values differ by exactly ``stepsize``."""
    # A new group starts right after every position where the step is broken.
    breaks = np.flatnonzero(np.diff(arr) != stepsize)
    group_starts = breaks + 1
    return np.split(arr, group_starts)
[docs]
class RepetitionTuple(NamedTuple):
    """Describes one periodic (repeating) span found in an array.

    Instances are produced by :func:`find_periodic_sequences`.
    """

    # Index of the first element of the repeating span.
    start: int
    # Index one past the last element of the repeating span (exclusive).
    end: int
    # Length of the pattern that repeats.
    period: int
    # Number of whole repetitions, i.e. ``(end - start) // period``.
    times: int
[docs]
def find_periodic_sequences(
    arr: np.ndarray, max_period: int, min_period: int = 1, mask_value: int = -1
) -> Generator[RepetitionTuple, None, None]:
    """Function to find periodic sequences in an array.

    This function sweeps through the array and checks for sequences of length
    [min_period, max_period] that repeat at least 3 times. To do so, it
    reshapes the array into a matrix with `period` columns and checks if each
    row is equal to the previous row. Blocks of repeating rows indicate repeating
    sequences.

    Because there's no guarantee that the sequences start at the beginning of each
    row, it can only detect sequences that repeat at least 3 times. To account
    for the fact that sequences may not start at the beginning of each row (or
    end at the end of each row), we check the end of the previous row and the
    start of the next row to determine the actual start and end positions of the
    sequence.

    Args:
        arr (np.ndarray): The array to search for periodic sequences.
        max_period (int): The maximum period to check for.
        min_period (int, optional): The minimum period to check for. Defaults to 1.
        mask_value (int, optional): The value to use to pad the array. Defaults to -1.

    Yields:
        RepetitionTuple: One entry per detected sequence, with an exclusive ``end``.

    Raises:
        ValueError: If ``mask_value`` occurs in ``arr``.
    """
    # make sure the mask_value is not in the array, otherwise padding could be
    # mistaken for real data
    if (arr == mask_value).sum() > 0:
        raise ValueError("`mask_value` is in the array")

    # since we can only detect sequences that repeat at least 3 times,
    # there is no point in checking for periods greater than 1/3 of the length
    max_period = min(max_period, len(arr) // 3)

    for period in range(min_period, max_period + 1):
        # pad the array so that it can be reshaped into a matrix matching the period;
        # note that between 1 and `period` mask values are always appended, so the final
        # row always contains at least one mask value
        padded_arr = np.pad(arr, (0, period - (len(arr) % period)), constant_values=mask_value)
        shaped_arr = padded_arr.reshape(-1, period)

        # find rows that are equal to the previous row; these are the possibly-periodic sequences
        # (`np.roll` wraps around, so the first row is compared against the padded final row)
        is_equal_to_prev_row = shaped_arr == np.roll(shaped_arr, shift=1, axis=0)
        rows_with_period, *_ = np.where(is_equal_to_prev_row.all(axis=1))

        # no sequences found with this period
        if len(rows_with_period) == 0:
            continue

        # this finds the start and end positions of the sequences with period `period`
        where_true_consecutive = group_consecutive_values(rows_with_period)

        for sequence in where_true_consecutive:
            start_row = sequence[0]
            end_row = sequence[-1]

            # we check if any value at the end of the previous row is True, e.g.:
            #     [[False, False, True, True]
            #      [True, True, True, True]]
            # (in the case above, start offset is 2). If so, we subtract that from the
            # period to get the actual start offset.
            start_offset = find_start_last_consecutive_true(is_equal_to_prev_row[start_row - 1])
            start_offset = period - start_offset if start_offset > 0 else 0

            # same idea as above, we want to compute offset. Only difference is that
            # `find_end_first_consecutive_true` already returns the offset, so we don't
            # need to subtract from the period.
            # NOTE(review): `end_row + 1` relies on the padded final row never matching its
            # predecessor, so `end_row` is never the last row -- confirm if padding changes.
            end_offset = find_end_first_consecutive_true(is_equal_to_prev_row[end_row + 1])

            # because we are always comparing with preceding row in
            # `is_equal_to_prev_row`, we need to subtract 1 from the row number
            start_pos = (start_row - 1) * period - start_offset

            # note that the end position is exclusive
            end_pos = ((end_row + 1) * period) + end_offset

            out = RepetitionTuple(
                start=start_pos, end=end_pos, period=period, times=(end_pos - start_pos) // period
            )
            if out.times > 2:
                # cannot accurately determine the period of a sequence that repeats
                # less than 3 times with this algorithm
                yield out
T = TypeVar("T")


def _take(n: int, iterable: Iterable[T]) -> List[T]:
    """Consume up to ``n`` items from ``iterable`` and return them as a list."""
    head = islice(iterable, n)
    return list(head)
[docs]
def chunked(iterable: Iterable[T], n: int) -> Iterable[List[T]]:
    """
    Group items in the iterable into chunks of size `n`, at most. This is equivalent to the function
    from ``more-itertools`` with the same name and ``strict=False``.

    :param iterable: The items to group.
    :param n: The maximum number of items per chunk.

    :returns: A lazy iterator over lists of items. Only the last chunk may be shorter than `n`.
    """
    source = iter(iterable)
    # Keep pulling chunks from the shared iterator until an empty chunk signals exhaustion.
    return iter(lambda: _take(n, source), [])
#########################################################################################################################
# Implementation of the Optimized Best-Fit Decreasing (OBFD) bin packing algorithm from https://arxiv.org/pdf/2404.10830.
# See Appendix B for a detailed illustration of the algorithm.
#########################################################################################################################
[docs]
@dataclass
class SegmentTreeNode:
    """
    A node of the binary tree used by :class:`SegmentTree`.

    An internal node's weight is kept equal to the larger of its two children's weights;
    a leaf's weight is set explicitly through :meth:`update`.
    """

    # Weight of a leaf, or the maximum weight of the subtree for an internal node.
    weight: int = 0
    # Parent node; ``None`` for the root.
    parent: Optional["SegmentTreeNode"] = None
    # (left, right) children; ``None`` for a leaf.
    children: Optional[Tuple["SegmentTreeNode", "SegmentTreeNode"]] = None
    # Position of this node among the leaves; ``None`` until assigned.
    leaf_id: Optional[int] = None

    @property
    def is_root(self) -> bool:
        """Whether this node has no parent."""
        return self.parent is None

    @property
    def is_leaf(self) -> bool:
        """Whether this node has no children."""
        return self.children is None

    def update(self, weight: Optional[int] = None):
        """
        Set this node's weight and refresh the weights of all of its ancestors.

        :param weight: The new weight; may only be given for a leaf. When omitted, the weight
            is recomputed from the children.
        """
        if weight is None:
            assert self.children is not None
            left, right = self.children
            self.weight = max(left.weight, right.weight)
        else:
            assert self.is_leaf
            self.weight = weight

        # Walk up to the root, recomputing the maximum at every ancestor.
        ancestor = self.parent
        while ancestor is not None:
            assert ancestor.children is not None
            left, right = ancestor.children
            ancestor.weight = max(left.weight, right.weight)
            ancestor = ancestor.parent
class SegmentTree:
    """
    A complete binary tree over ``N`` leaves in which every internal node holds the maximum
    weight of its subtree.

    On construction only the last leaf has a non-zero weight (``N``); all others are 0.
    """

    def __init__(self, N: int):
        """
        :param N: The number of leaves. Must be a power of 2.
        """
        assert math.log2(N) % 1 == 0, "N should be a power of 2"
        self.root_node = SegmentTreeNode()

        # Grow the tree one level at a time, left to right, until the leaf level is reached.
        depth = int(math.log2(N))
        level: List[SegmentTreeNode] = [self.root_node]
        for _ in range(depth):
            next_level: List[SegmentTreeNode] = []
            for node in level:
                node.children = (SegmentTreeNode(parent=node), SegmentTreeNode(parent=node))
                next_level.extend(node.children)
            level = next_level

        for leaf_id, leaf in enumerate(level):
            leaf.leaf_id = leaf_id
        self.leaf_nodes: List[SegmentTreeNode] = level

        assert len(self.leaf_nodes) == N
        self.leaf_nodes[-1].update(N)

    def query(self, weight: int) -> SegmentTreeNode:
        """
        Find the left-most leaf whose weight is at least ``weight``.

        :param weight: The minimum weight required. Must not exceed the root's weight.
        """
        node = self.root_node
        while not node.is_leaf:
            assert weight <= node.weight
            assert node.children is not None
            left_child, right_child = node.children
            # Prefer the left subtree whenever it can satisfy the request.
            node = left_child if weight <= left_child.weight else right_child
        return node
class InstancePacker:
    """
    Packs documents into fixed-capacity instances ("bins") with the Optimized Best-Fit
    Decreasing algorithm: each document goes into the bin with the smallest remaining
    space that still fits it, and a new bin is opened when none does.
    """

    def __init__(self, max_sequence_length: int):
        """
        :param max_sequence_length: The capacity of each instance. Must be a power of 2, as
            required by :class:`SegmentTree`.
        """
        self.max_sequence_length = max_sequence_length
        # Leaf `i` of the tree has weight `i + 1` when a bin with `i + 1` free slots exists.
        self.seg_tree = SegmentTree(max_sequence_length)
        # Document IDs held by each instance.
        self.instance_bins: List[List[int]] = []
        # Maps an amount of free space to the IDs of the bins with exactly that much left.
        self.space_to_bins: Dict[int, deque[int]] = defaultdict(deque)

    @property
    def total_padding(self) -> int:
        """The total unused space across all partially filled bins."""
        return sum(
            space * len(self.space_to_bins[space])
            for space in range(1, self.max_sequence_length)
            if space in self.space_to_bins
        )

    @property
    def total_tokens(self) -> int:
        """The total number of tokens packed so far."""
        return self.max_sequence_length * len(self.instance_bins) - self.total_padding

    def _pack_document(self, document_id: int, document_length: int) -> int:
        """
        Place a single document into the best-fitting bin.

        :returns: The ID of the bin that received the document.
        """
        # Query for the smallest available capacity that fits the document.
        capacity_leaf_id = self.seg_tree.query(document_length).leaf_id
        assert capacity_leaf_id is not None
        capacity = capacity_leaf_id + 1

        if capacity != self.max_sequence_length:
            # Reuse the oldest bin with exactly this much space left.
            candidates = self.space_to_bins[capacity]
            bin_id = candidates.popleft()
            if not candidates:
                # No bin with this capacity remains, so retire it from the tree.
                self.seg_tree.leaf_nodes[capacity - 1].update(weight=0)
        else:
            # Only an empty bin is big enough: open a new one.
            bin_id = len(self.instance_bins)
            self.instance_bins.append([])

        self.instance_bins[bin_id].append(document_id)

        # Register the bin under its new remaining capacity, if it has any.
        leftover = capacity - document_length
        if leftover > 0:
            same_space_bins = self.space_to_bins[leftover]
            if not same_space_bins:
                self.seg_tree.leaf_nodes[leftover - 1].update(weight=leftover)
            same_space_bins.append(bin_id)

        return bin_id

    def pack_documents(
        self, document_indices: np.ndarray
    ) -> Tuple[List[List[int]], np.ndarray, int]:
        """
        Pack all documents into instances.

        :param document_indices: Array of shape ``(num_documents, 2)`` holding the start and
            end index of each document.

        :returns: The instances (lists of document IDs), the document indices sorted by
            decreasing length (document IDs index into this array), and the total number
            of tokens packed.

        :raises RuntimeError: If the packer holds state from a previous call; call
            :meth:`reset` first.
        """
        if self.instance_bins or self.space_to_bins:
            raise RuntimeError(
                f"You must call '{self.__class__.__name__}.reset()' before "
                f"calling '{self.__class__.__name__}.pack_documents()' again."
            )

        # Sort documents by length, longest first. The cast to a signed type makes the
        # negation safe for unsigned index dtypes.
        lengths = document_indices[:, 1] - document_indices[:, 0]
        longest_first = np.argsort(-1 * lengths.astype(np.int64))
        document_indices = np.take(document_indices, longest_first, axis=0)

        for document_id, bounds in enumerate(document_indices):
            self._pack_document(document_id, int(bounds[1] - bounds[0]))

        return self.instance_bins, document_indices, self.total_tokens

    def reset(self):
        """Discard all packing state so the packer can be reused."""
        self.seg_tree = SegmentTree(self.max_sequence_length)
        self.instance_bins.clear()
        self.space_to_bins.clear()
[docs]
def pack_documents_into_instances(
    *paths: PathOrStr,
    max_sequence_length: int,
    eos_token_id: int,
    dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]],
    bos_token_id: Optional[int] = None,
    indices_dtype: Union[
        Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]
    ] = np.uint64,
    long_doc_strategy: LongDocStrategy = LongDocStrategy.truncate,
) -> Tuple[List[List[int]], np.ndarray, int]:
    """
    Pack documents from source files into instances of at most ``max_sequence_length`` using
    a best-fit-decreasing algorithm described in https://arxiv.org/pdf/2404.10830.

    :param paths: Paths/URLs to the source files of token IDs. When multiple sources are given, they'll
        be treated as if they've been concatenated together into a single source file.
    :param max_sequence_length: The maximum sequence length of each *instance*.
    :param eos_token_id: The EOS token ID, used to find document boundaries.
    :param bos_token_id: The BOS token ID, used to find document boundaries in conjunction with the EOS
        token ID.
    :param dtype: The numpy datatype of the source file.
    :param indices_dtype: The numpy datatype to use for document indices.
    :param long_doc_strategy: Specifies how to handle documents that are longer than ``max_sequence_length``.
        If set to "truncate" then those documents are just truncated to ``max_sequence_length`` and
        the excess tokens are discarded.
        If set to "fragment" then those documents are split into smaller documents so that no tokens
        are discarded, but you end up with fragmented documents.

    :returns: A list of instances, where each instance is a list of document IDs, a 2D array
        of the corresponding document start and end indices, with shape ``(num_documents, 2)``,
        and the total number of tokens packed into instances.

    :raises RuntimeError: If no source paths are given.
    """
    if not paths:
        raise RuntimeError("At least one source path must be provided")

    def flat_doc_bounds() -> Generator[int, None, None]:
        # Emits start_0, end_0, start_1, end_1, ... in the coordinates of the virtual
        # concatenation of all sources.
        tokens_before = 0
        for source in paths:
            source_docs = iter_document_indices_with_max_sequence_length(
                source,
                max_sequence_length,
                eos_token_id=eos_token_id,
                bos_token_id=bos_token_id,
                dtype=dtype,
                long_doc_strategy=long_doc_strategy,
            )
            for doc_start, doc_end in source_docs:
                yield tokens_before + doc_start
                yield tokens_before + doc_end
            tokens_before += get_file_size(source) // dtype(0).itemsize

    # shape: (num_docs, 2)
    document_indices = np.fromiter(flat_doc_bounds(), dtype=indices_dtype).reshape(-1, 2)

    packer = InstancePacker(max_sequence_length)
    return packer.pack_documents(document_indices)
[docs]
def attention_mask_to_cache_leftpad(
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Convert a left-padding attention mask into a cache leftpad for Flash-Attention.

    The mask is expected to be a boolean or 0/1 tensor of shape ``(batch, seq_len)`` where
    ``True``/1 indicates a *valid* token and the padding is on the **left** side of the
    sequence (i.e. all padding tokens come *before* all valid tokens).

    Returns:
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts.

    Raises:
        ValueError: If the mask is not 2-D, or if any padding token follows a valid token.
    """
    if attention_mask.ndim != 2:
        raise ValueError(
            f"expected 2-D attention_mask (batch, seq_len), got shape {attention_mask.shape}"
        )

    is_valid = attention_mask if attention_mask.dtype == torch.bool else attention_mask != 0

    # `seen_valid` becomes (and stays) True from the first valid token of each row onwards, so
    # any position that is marked as seen but isn't valid is padding after a valid token.
    seen_valid = is_valid.cummax(dim=1).values
    padding_after_valid = (seen_valid & ~is_valid).any().item()
    if padding_after_valid:
        raise ValueError(
            "attention_mask must represent *prefix padding* (all padding tokens precede valid tokens) "
            "for conversion to flash attention cache leftpad."
        )

    # The first valid position of each row equals the amount of left padding.
    first_valid = is_valid.to(torch.int32).argmax(dim=-1)  # (B,)
    return first_valid.to(torch.int32)