Source code for olmo_core.utils

import dataclasses
import gc
import logging
import math
import os
import socket
import sys
import time
import uuid
import warnings
from contextlib import contextmanager
from datetime import datetime, timedelta
from functools import lru_cache
from itertools import cycle, islice
from queue import Queue
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union, cast

import rich
import torch
import torch.distributed as dist
from rich.console import Console, ConsoleRenderable
from rich.highlighter import NullHighlighter
from rich.text import Text
from rich.traceback import Traceback
from torch.utils.flop_counter import FlopCounterMode

from .config import StrEnum
from .exceptions import OLMoCLIError, OLMoEnvironmentError, OLMoError, OLMoThreadError

# Name of the env var that overrides the default thread count
# (read by ``get_default_thread_count()``).
OLMO_NUM_THREADS_ENV_VAR = "OLMO_NUM_THREADS"
# Name of the env var that selects the log filter type
# (read by ``prepare_cli_environment()``).
LOG_FILTER_TYPE_ENV_VAR = "LOG_FILTER_TYPE"
# Name of the env var that switches rich logging on or off
# (read by ``setup_logging()``).
RICH_LOGGING_ENV_VAR = "OLMO_RICH_LOGGING"


# Extra fields that the record factory installed by ``setup_logging()`` copies onto
# each log record. Managed through ``log_extra_field()``.
_log_extra_fields: Dict[str, Any] = {}
# Flipped to ``True`` at the end of ``setup_logging()``.
_LOGGING_CONFIGURED: bool = False
log = logging.getLogger(__name__)


def generate_uuid() -> str:
    """
    Create a new random (version 4) UUID and return it in its string form.
    """
    new_id = uuid.uuid4()
    return str(new_id)
def get_default_thread_count() -> int:
    """
    Return the default cap on the number of threads.

    When the env var named by ``OLMO_NUM_THREADS_ENV_VAR`` is set, its integer value is used.
    Otherwise the default is ``min(16, cpu_count + 4)``, treating an unknown CPU count as 1.

    :raises OLMoEnvironmentError: If the env var is set but cannot be parsed as an integer.
    """
    raw_value = os.environ.get(OLMO_NUM_THREADS_ENV_VAR)

    if raw_value is None:
        cpu_count = os.cpu_count() or 1
        return min(16, cpu_count + 4)

    try:
        return int(raw_value)
    except ValueError:
        raise OLMoEnvironmentError(
            f"Invalid value for {OLMO_NUM_THREADS_ENV_VAR} environment variable ('{raw_value}')"
        )
def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0):
    """
    Block until ``condition()`` returns a truthy value.

    The condition is polled at most every 0.5 seconds. It is always re-checked after the
    last sleep, so a condition that becomes true right at the deadline is not reported as
    a timeout, and the last sleep never runs past the deadline.

    :param condition: A zero-argument callable that is polled until it returns ``True``.
    :param description: Text appended to the timeout error message (``"timed out ..."``).
    :param timeout: Maximum number of seconds to wait.

    :raises TimeoutError: If the condition is still false once ``timeout`` seconds have elapsed.
    """
    poll_interval = 0.5
    deadline = time.monotonic() + timeout
    while not condition():
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"timed out {description}")
        # Never sleep past the deadline so short timeouts are honoured.
        time.sleep(min(poll_interval, remaining))
def apply_to_tensors(fn, container: Any) -> None:
    """
    Call ``fn`` on every tensor found by walking ``container`` recursively.

    The walk descends into lists, tuples, sets, dicts (keys and values), dataclass fields,
    and anything that exposes ``__next__``. Other objects are ignored.

    :param fn: Callable invoked with each tensor. Its return value is discarded.
    :param container: A tensor or an arbitrarily nested container of tensors.
    """
    if isinstance(container, torch.Tensor):
        fn(container)
        return

    if isinstance(container, (list, tuple, set)):
        for item in container:
            apply_to_tensors(fn, item)
        return

    if isinstance(container, dict):
        # Visit the key first and then the value for each entry.
        for entry in container.items():
            for member in entry:
                apply_to_tensors(fn, member)
        return

    if hasattr(container, "__dataclass_fields__"):
        for field in dataclasses.fields(container):
            apply_to_tensors(fn, getattr(container, field.name))
        return

    if hasattr(container, "__next__"):
        for item in container:
            apply_to_tensors(fn, item)
T = TypeVar("T")


def move_to_device(o: T, device: torch.device, non_blocking: Optional[bool] = None) -> T:
    """
    Move a tensor or container of tensors to the given device.

    Dicts, lists and tuples are rebuilt with their contents moved recursively. Any other
    non-tensor object is returned unchanged.

    :param o: The object to move.
    :param device: The device to move to.
    :param non_blocking: Passed to :meth:`torch.Tensor.to()` for every tensor, including
        tensors nested inside containers. If ``None``, it defaults to ``True`` unless the
        target device type is ``"cpu"`` or ``"mps"``.
    """
    if non_blocking is None:
        non_blocking = device.type not in ("cpu", "mps")

    if isinstance(o, torch.Tensor):
        return o.to(device, non_blocking=non_blocking)  # type: ignore[return-value]
    elif isinstance(o, dict):
        # Forward the resolved ``non_blocking`` so nested tensors honour the caller's choice.
        return {  # type: ignore[return-value]
            k: move_to_device(v, device, non_blocking=non_blocking) for k, v in o.items()
        }
    elif isinstance(o, list):
        return [move_to_device(x, device, non_blocking=non_blocking) for x in o]  # type: ignore[return-value]
    elif isinstance(o, tuple):
        return tuple(move_to_device(x, device, non_blocking=non_blocking) for x in o)  # type: ignore[return-value]
    else:
        return o
[docs] @torch.compiler.disable def mark_dynamic(x: torch.Tensor, dim: Union[int, Sequence[int]], strict: bool = True): """ Mark a tensor as having dynamic sizes for ``torch.compile()``. """ if strict: torch._dynamo.mark_dynamic(x, dim) else: torch._dynamo.maybe_mark_dynamic(x, dim)
def get_default_device() -> torch.device:
    """
    Pick the default device.

    In a distributed setting the choice follows the process group backend: ``cuda`` if the
    backend supports it, else ``cpu`` if the backend supports that. Outside of a distributed
    setting the first available of ``cuda``, ``mps`` and ``cpu`` is used.

    :raises NotImplementedError: If the distributed backend supports neither CUDA nor CPU.
    """
    from .distributed.utils import (
        backend_supports_cpu,
        backend_supports_cuda,
        is_distributed,
    )

    if is_distributed():
        backend = dist.get_backend()
        if backend_supports_cuda(backend):
            return torch.device("cuda")
        if backend_supports_cpu(backend):
            return torch.device("cpu")
        raise NotImplementedError(backend)

    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    return torch.device(device_type)
def has_compute_capability(major: int, minor: int) -> bool:
    """
    Check whether CUDA is available and the current device's compute capability is at
    least ``(major, minor)``.
    """
    if not torch.cuda.is_available():
        return False
    required = (major, minor)
    return torch.cuda.get_device_capability() >= required
def seed_all(seed: int):
    """
    Seed the RNGs of Python's ``random`` module, NumPy and PyTorch (CPU and CUDA).

    :param seed: The seed, which must lie in ``[0, 2**32 - 1]``.

    :raises ValueError: If the seed is outside of the valid range.
    """
    import random

    import numpy as np

    max_seed = 2**32 - 1
    if seed < 0 or seed > max_seed:
        raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]")

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # torch.manual_seed may already call manual_seed_all, but calling it explicitly
    # guarantees it runs at least once.
    torch.cuda.manual_seed_all(seed)
def same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
    """
    Check whether two tensors share the same storage.

    The data pointers of every element of both (flattened) tensors are collected, and the
    tensors are considered to share storage if one pointer set is a subset of the other.
    """
    x_ptrs = {element.data_ptr() for element in x.view(-1)}
    y_ptrs = {element.data_ptr() for element in y.view(-1)}
    return x_ptrs.issubset(y_ptrs) or y_ptrs.issubset(x_ptrs)
def gc_cuda():
    """
    Trigger Python garbage collection and then, when CUDA is available, release the
    cached CUDA memory as well.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def has_flash_attn() -> bool:
    """
    Report whether the ``flash_attn`` package can be imported.
    """
    try:
        import flash_attn  # type: ignore
    except ModuleNotFoundError:
        return False
    else:
        # Only the import attempt matters; drop the local reference again.
        del flash_attn
        return True
def set_env_var(name: str, value: str, override: bool = False, secret: bool = False):
    """
    Set an environment variable, reporting what happened through :func:`log_or_print()`.

    :param name: The name of the environment variable.
    :param value: The value to set.
    :param override: If the variable is already set to a different value, replace it when
        ``True``; otherwise leave the existing value in place and emit a warning.
    :param secret: If ``True``, values are masked as ``****`` in the emitted messages.
    """
    value_str = "****" if secret else value
    current = os.environ.get(name)

    if current is None:
        log_or_print(log, f"Setting env var '{name}' to '{value_str}'")
        os.environ[name] = value
    elif current == value:
        log_or_print(log, f"Env var '{name}' already set to '{value_str}'")
    elif override:
        log_or_print(log, f"Overriding env var '{name}' to '{value_str}'", logging.WARNING)
        os.environ[name] = value
    else:
        # Report the value that is actually in the environment, not the one being skipped.
        current_str = "****" if secret else current
        log_or_print(
            log,
            f"Will skip setting env var '{name}' since it's already set to '{current_str}'",
            logging.WARNING,
        )


def log_or_print(logger: logging.Logger, msg: str, level: int = logging.INFO):
    """
    Emit ``msg`` through ``logger`` when logging is usable, otherwise ``print()`` it.

    The logger is used when this module's logging has been configured, or when the root
    logger has at least one handler and ``logger`` is enabled for ``level``.

    :param logger: The logger to emit through.
    :param msg: The message.
    :param level: The logging level to emit at.
    """
    root_has_handlers = len(logging.getLogger().handlers) > 0
    if _LOGGING_CONFIGURED or (root_has_handlers and logger.getEffectiveLevel() <= level):
        logger.log(level, msg)
    else:
        print(msg)
class LogFilterType(StrEnum):
    """
    Determines which ranks are allowed to emit messages at the ``INFO`` level and below.

    Used by :func:`setup_logging()` to choose the filter installed on the log handler.
    Records above the ``INFO`` level pass those filters on every rank.
    """

    rank0_only = "rank0_only"
    """ INFO messages are only emitted from the global rank 0. """

    local_rank0_only = "local_rank0_only"
    """ INFO messages are only emitted from the local (node) rank 0. """

    all_ranks = "all_ranks"
    """ All ranks emit INFO messages. """
def log_extra_field(field_name: str, field_value: Any) -> None:
    """
    Register an extra field to attach to each log record, or remove one.

    .. note::
        The field only shows up in the log output if the formatter/handler in use
        actually displays it.

    :param field_name: The name of the field to attach.
    :param field_value: The value of the field to attach. Passing ``None`` removes a
        previously registered field of that name.
    """
    if field_value is None:
        # Removing a field that was never registered is a no-op.
        _log_extra_fields.pop(field_name, None)
        return
    _log_extra_fields[field_name] = field_value
def setup_logging(
    log_filter_type: LogFilterType = LogFilterType.rank0_only, force: bool = False
) -> None:
    """
    Configure logging for the current process.

    This registers the ``hostname``, ``local_rank`` and ``global_rank`` extra fields,
    installs a log record factory that copies all registered extra fields onto each record,
    and configures the root logger at the ``INFO`` level with a single handler. The handler
    is either a rich handler or a plain stdout stream handler.

    .. seealso::
        :func:`prepare_cli_environment()`

    :param log_filter_type: Which ranks emit INFO and below messages.
    :param force: Force configuring logging even if it was already configured.

    :raises ValueError: If ``log_filter_type`` is not a recognized filter type.
    """
    global _LOGGING_CONFIGURED

    if _LOGGING_CONFIGURED and not force:
        return

    from .distributed.utils import get_local_rank, get_rank

    log_extra_field("hostname", socket.gethostname())
    log_extra_field("local_rank", get_local_rank())
    log_extra_field("global_rank", get_rank())

    base_factory = logging.getLogRecordFactory()

    def log_record_factory(*args, **kwargs) -> logging.LogRecord:
        # Build the record as usual, then copy every registered extra field onto it.
        record = base_factory(*args, **kwargs)
        for name, extra in _log_extra_fields.items():
            setattr(record, name, extra)
        return record

    logging.setLogRecordFactory(log_record_factory)

    rich_setting = os.environ.get(RICH_LOGGING_ENV_VAR)
    if rich_setting is None:
        # NOTE: Beaker supports rich logging now.
        use_rich_logging = os.environ.get("BEAKER_EXPERIMENT_ID") is not None or (
            os.environ.get("DEBIAN_FRONTEND", None) != "noninteractive" and sys.stdout.isatty()
        )
    else:
        use_rich_logging = rich_setting.lower() not in ("0", "false", "no", "off")

    handler: logging.Handler
    if use_rich_logging:
        handler = _RichHandler()
    else:
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter(
            "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s"
        )
        formatter.default_time_format = "%Y-%m-%d %H:%M:%S"
        formatter.default_msec_format = "%s.%03d"
        handler.setFormatter(formatter)

    def rank0_filter(record: logging.LogRecord) -> int:
        # Anything above INFO passes on every rank; the rest only on global rank 0.
        if record.levelno > logging.INFO:
            return 1
        return 1 if getattr(record, "global_rank", 0) == 0 else 0

    def local_rank0_filter(record: logging.LogRecord) -> int:
        # Anything above INFO passes on every rank; the rest only on local rank 0.
        if record.levelno > logging.INFO:
            return 1
        return 1 if getattr(record, "local_rank", 0) == 0 else 0

    if log_filter_type == LogFilterType.rank0_only:
        handler.addFilter(rank0_filter)  # type: ignore
    elif log_filter_type == LogFilterType.local_rank0_only:
        handler.addFilter(local_rank0_filter)  # type: ignore
    elif log_filter_type != LogFilterType.all_ranks:
        raise ValueError(log_filter_type)

    logging.basicConfig(handlers=[handler], level=logging.INFO, force=True)
    logging.captureWarnings(True)

    for noisy_logger, noisy_level in (("urllib3", logging.ERROR), ("google", logging.WARNING)):
        logging.getLogger(noisy_logger).setLevel(noisy_level)

    _LOGGING_CONFIGURED = True
def logging_configured() -> bool:
    """
    Report whether logging has been configured, either through :func:`setup_logging()`
    or because the root logger already has at least one handler.
    """
    if not _LOGGING_CONFIGURED:
        # Fall back to checking the root logger for handlers.
        root_handlers = logging.getLogger().handlers
        return len(root_handlers) > 0
    return True
def excepthook(exctype, value, traceback):
    """
    Replacement for ``sys.excepthook`` that logs uncaught exceptions.
    Use :func:`install_excepthook()` to install this.

    - ``KeyboardInterrupt`` is passed on to the original ``sys.__excepthook__``.
    - ``OLMoCLIError`` and ``OLMoError`` are printed to the rich console without a traceback.
    - Everything else is logged at the ``CRITICAL`` level together with the traceback.
    """
    if issubclass(exctype, KeyboardInterrupt):
        sys.__excepthook__(exctype, value, traceback)
        return

    if issubclass(exctype, OLMoCLIError):
        rich.get_console().print(f"[yellow]{value}[/]", highlight=False)
        return

    if issubclass(exctype, OLMoError):
        label = Text(f"{exctype.__name__}:", style="red")
        rich.get_console().print(label, value, highlight=False)
        return

    log.critical(
        "Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback)
    )
def install_excepthook():
    """
    Replace ``sys.excepthook`` with the custom :func:`excepthook`.

    .. seealso::
        :func:`prepare_cli_environment()`
    """
    hook = excepthook
    sys.excepthook = hook
def filter_warnings():
    """
    Install ``ignore`` filters for a fixed set of warnings we don't need to see.

    .. seealso::
        :func:`prepare_cli_environment()`
    """
    ignored: List[Dict[str, Any]] = [
        # Filter internal deprecation warnings from torch
        dict(
            category=UserWarning,
            message="torch.distributed.*_base is a private function and will be deprecated.*",
        ),
        dict(category=UserWarning, message="TypedStorage is deprecated.*"),
        dict(category=UserWarning, message="Please use DTensor instead.*"),
        dict(
            category=UserWarning,
            message="Synchronization debug mode is a prototype feature.*",
            module="torch.cuda",
        ),
        dict(
            category=UserWarning,
            message="TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect .*",
        ),
        dict(
            category=FutureWarning,
            message="You are using `torch.load` with `weights_only=False`.*",
        ),
        # flash_attn warnings.
        dict(category=FutureWarning, module="flash_attn.ops.triton.layer_norm"),
        # Torchvision warnings. We don't actually use torchvision.
        dict(message="failed to load.*", module="torchvision.io.image"),
    ]

    # Register in listed order; each new filter is inserted ahead of the earlier ones.
    for spec in ignored:
        warnings.filterwarnings(action="ignore", **spec)
def set_env_variables():
    """
    Apply the common environment variable defaults through ``set_env_var()``:
    ``OMP_NUM_THREADS=8`` and ``TOKENIZERS_PARALLELISM=false``.

    .. seealso::
        :func:`prepare_cli_environment()`
    """
    defaults = (
        ("OMP_NUM_THREADS", "8"),
        ("TOKENIZERS_PARALLELISM", "false"),
    )
    for var_name, var_value in defaults:
        set_env_var(var_name, var_value)
def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None):
    """
    Prepare the environment for a script/CLI.

    Call this at the very beginning of the script/command, e.g. as the first statement of
    the ``if __name__ == "__main__": ...`` block. After widening the rich console to at
    least 180 columns, it runs the following in order:

    - :func:`setup_logging()`
    - :func:`install_excepthook()`
    - :func:`filter_warnings()`
    - :func:`set_env_variables()`
    - :func:`~olmo_core.io.add_cached_path_clients()`

    .. tip::
        If you're looking to setup the environment specifically for distributed training,
        see :func:`~olmo_core.train.prepare_training_environment` instead.

    :param log_filter_type: Determines which ranks are allowed to emit log messages below the
        ``WARNING`` level. You can also configure this through the env var ``LOG_FILTER_TYPE``.
        If neither are set, this defaults to "rank0_only".

        .. note::
            All ranks will always emit messages at the ``WARNING`` level or higher.
    """
    from .io import add_cached_path_clients

    if log_filter_type is None:
        configured = os.environ.get(LOG_FILTER_TYPE_ENV_VAR, "rank0_only")
        log_filter_type = LogFilterType(configured)

    console_width = max(rich.get_console().width, 180)
    rich.reconfigure(width=console_width, soft_wrap=True)

    setup_logging(log_filter_type=log_filter_type)
    install_excepthook()
    filter_warnings()
    set_env_variables()
    add_cached_path_clients()
class _RichHandler(logging.Handler):
    """
    A pared-down take on ``rich.logging.RichHandler`` from
    https://github.com/Textualize/rich/blob/master/rich/logging.py

    Each record is printed to a rich console as: time, level, location, message, plus a
    rendered traceback when the record carries exception info.
    """

    def __init__(
        self,
        *,
        level: Union[int, str] = logging.NOTSET,
        console: Optional[Console] = None,
        markup: bool = False,
    ) -> None:
        super().__init__(level=level)
        self.console = console or rich.get_console()
        self.highlighter = NullHighlighter()
        self.markup = markup

    def emit(self, record: logging.LogRecord) -> None:
        """Print the record to the console, routing any failure to ``handleError()``."""
        try:
            payload: Any = record.msg

            # Objects that know how to render themselves are printed as-is.
            if hasattr(payload, "__rich__") or hasattr(payload, "__rich_console__"):
                self.console.print(payload)
                return

            if isinstance(payload, str):
                payload = self.render_message(record=record, message=record.getMessage())

            segments = [
                self.get_time_text(record),
                self.get_level_text(record),
                self.get_location_text(record),
                payload,
            ]
            if record.exc_info is not None:
                segments.append(Traceback.from_exception(*record.exc_info))  # type: ignore
            self.console.print(*segments)
        except Exception:
            self.handleError(record)

    def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable:
        """Turn the message string into rich text, honouring per-record overrides."""
        markup_enabled = getattr(record, "markup", self.markup)
        if markup_enabled:
            rendered = Text.from_markup(message)
        else:
            rendered = Text(message)

        active_highlighter = getattr(record, "highlighter", self.highlighter)
        if active_highlighter:
            rendered = active_highlighter(rendered)
        return rendered

    def get_time_text(self, record: logging.LogRecord) -> Text:
        """Format the record's creation time as ``[date time]``."""
        created_at = datetime.fromtimestamp(record.created)
        return Text(created_at.strftime("[%Y-%m-%d %X]"), style="log.time", end=" ")

    def get_level_text(self, record: logging.LogRecord) -> Text:
        """Render the level name, left-justified to eight characters."""
        name = record.levelname
        styled = Text.styled(name.ljust(8), f"logging.level.{name.lower()}")
        styled.style = "log.level"
        styled.end = " "
        return styled

    def get_location_text(self, record: logging.LogRecord) -> Text:
        """Render the logger name/line together with the hostname and local rank."""
        if record.name != "root":
            name_and_line = f"{record.name}:{record.lineno}"
        else:
            name_and_line = "root"
        location = f"[{name_and_line}, {record.hostname}, rank={record.local_rank}]"  # type: ignore
        return Text(location, style="log.path")
def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None):
    """
    Wraps a generator ``g`` and runs it in a daemon thread, yielding its items in order.

    Items travel through a bounded queue tagged with private marker objects that are
    compared by identity. The yielded values themselves are therefore never compared, so
    values with custom ``__eq__`` (e.g. arrays) and values that happen to be ``Exception``
    instances are passed through untouched.

    :param g: The generator (or any iterable) to consume in the background thread.
    :param maxsize: Maximum number of items buffered in the queue.
    :param thread_name: Name for the worker thread. Defaults to ``repr(g)``.

    :raises OLMoThreadError: If iterating ``g`` raises an exception in the worker thread.
        The original exception is attached as the cause.
    """
    q: Queue = Queue(maxsize=maxsize)
    # Private tags; checked with ``is`` so user values can never be mistaken for them.
    item_tag = object()
    error_tag = object()
    done_tag = object()

    def fill_queue():
        try:
            for value in g:
                q.put((item_tag, value))
        except Exception as e:
            q.put((error_tag, e))
        finally:
            q.put((done_tag, None))

    thread_name = thread_name or repr(g)
    thread = Thread(name=thread_name, target=fill_queue, daemon=True)
    thread.start()

    while True:
        tag, payload = q.get()
        if tag is done_tag:
            return
        if tag is error_tag:
            raise OLMoThreadError(f"generator thread {thread_name} failed") from payload
        yield payload
def roundrobin(*iterables):
    """
    Interleave the given iterables, taking one item from each in turn and dropping
    each iterable as soon as it is exhausted.

    For example: ``roundrobin('ABC', 'D', 'EF') --> A D E B F C``
    """
    # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes
    remaining = len(iterables)
    pullers = cycle(iter(iterable).__next__ for iterable in iterables)
    while remaining:
        try:
            for pull in pullers:
                yield pull()
        except StopIteration:
            # The puller that just raised is exhausted: keep only the ones after it.
            remaining -= 1
            pullers = cycle(islice(pullers, remaining))
def powers_of_2(x: int) -> List[int]:
    """
    Decompose ``x`` into the powers of two that sum to it, largest first.
    Returns an empty list when ``x`` is less than 1.
    """
    ascending: List[int] = []
    bit = 1
    while bit <= x:
        if x & bit:
            ascending.append(bit)
        bit <<= 1
    ascending.reverse()
    return ascending


def capped_powers_of_2(x: int, cap: int) -> List[int]:
    """
    Like :func:`powers_of_2()`, but every power larger than ``cap`` is replaced by the
    equivalent number of repetitions of ``cap``.
    """
    result = []
    for power in powers_of_2(x):
        if power <= cap:
            result.append(power)
            continue
        assert power % cap == 0
        result.extend([cap] * (power // cap))
    return result


def _decimals_for(magnitude: float) -> int:
    """Number of decimal places to show: fewer for larger magnitudes."""
    if magnitude > 100:
        return 1
    if magnitude > 10:
        return 2
    if magnitude > 1:
        return 3
    return 4


def format_float(value: float) -> str:
    """
    Format a float compactly for display.

    NaN and infinities are spelled out, values of at least 1e9 get a ``B``/``T``/``P``/``E``
    suffix, values below 1e-4 use scientific notation, values of at least 1000 are shown as
    comma-grouped integers, and everything else is shown with 1 to 4 decimal places.
    """
    if math.isnan(value):
        return "nan"
    if math.isinf(value):
        return "inf" if value > 0 else "-inf"

    magnitude = abs(value)
    if magnitude == 0.0:
        return "0.0"

    if magnitude >= 1e9:
        # Try the largest suffix first and use the first one that fits.
        for suffix, factor in (("E", 1e18), ("P", 1e15), ("T", 1e12), ("B", 1e9)):
            if magnitude >= factor:
                scaled = value / factor
                scaled_str = f"{scaled:,.{_decimals_for(abs(scaled))}f}"
                return f"{scaled_str}{suffix}"

    if magnitude < 0.0001:
        return f"{value:.2E}"
    if magnitude >= 1000:
        return f"{int(value):,d}"
    return f"{value:.{_decimals_for(magnitude)}f}"
def format_int(count: int) -> str:
    """
    Format a large integer into a more human-readable string.

    Counts below 1,000 are returned as-is. Larger counts are scaled to thousands (``K``),
    millions (``M``), billions (``B``) or trillions (``T``) with at most one decimal place;
    a trailing ``.0`` is dropped.

    The unit is chosen *after* rounding, so a count that rounds up to 1000 of one unit is
    shown in the next unit (e.g. ``999_999`` gives ``"1M"`` rather than ``"1000K"``).

    :param count: The integer to format.
    """
    if count < 1_000:
        return f"{count}"

    units = (
        ("K", 1_000),
        ("M", 1_000_000),
        ("B", 1_000_000_000),
        ("T", 1_000_000_000_000),
    )
    for suffix, factor in units:
        scaled = f"{count / factor:.1f}"
        if float(scaled) < 1_000:
            break
    # If no unit fit, the loop falls through with the largest unit ("T") still bound.

    if scaled.endswith(".0"):
        scaled = scaled[:-2]
    return f"{scaled}{suffix}"
def format_timedelta(td: timedelta) -> str:
    """
    Render a timedelta as a comma-separated list of its non-zero day, hour, minute and
    second components (e.g. ``"1d, 2h, 3m, 4s"``), or ``"0s"`` when all of them are zero.
    """
    hours, leftover = divmod(td.seconds, 3600)
    minutes, seconds = divmod(leftover, 60)
    components = ((td.days, "d"), (hours, "h"), (minutes, "m"), (seconds, "s"))
    parts = [f"{amount}{unit}" for amount, unit in components if amount > 0]
    return ", ".join(parts) if parts else "0s"
def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """
    Collapse a nested dictionary with string keys into a single level, joining the keys
    along each path with dots (``{"a": {"b": 1}}`` becomes ``{"a.b": 1}``).
    """

    def walk(prefix, mapping):
        # Yield (flattened key, leaf value) pairs in depth-first insertion order.
        for key, value in mapping.items():
            path = key if prefix is None else f"{prefix}.{key}"
            if isinstance(value, dict):
                yield from walk(path, value)
            else:
                yield path, value

    return dict(walk(None, d))
@contextmanager
def cuda_sync_debug_mode(debug_mode: Union[int, str]):
    """
    Context manager that switches the CUDA sync debug mode to ``debug_mode`` for the
    duration of the block and restores the previous mode afterwards.

    When CUDA is not available this does nothing.
    """
    if not torch.cuda.is_available():
        yield
        return

    previous_mode = torch.cuda.get_sync_debug_mode()
    try:
        torch.cuda.set_sync_debug_mode(debug_mode)
        yield
    finally:
        torch.cuda.set_sync_debug_mode(previous_mode)
def get_element_size(dtype: torch.dtype) -> int:
    """
    Return how many bytes a single element of the given PyTorch dtype occupies.
    """
    size_in_bytes = torch._utils._element_size(dtype)  # type: ignore
    return size_in_bytes
def ensure_multiple_of(x: int, of: int) -> int:
    """
    Round ``x`` up to the nearest multiple of ``of``.

    Uses integer ceil-division, so the result is exact for arbitrarily large integers
    (float division loses precision above ``2**53``).
    """
    return of * -(-x // of)


@lru_cache(maxsize=128)
def log_once(logger: logging.Logger, msg: str, *args, level: int = logging.INFO, **kwargs):
    """
    Log a message through ``logger``, skipping repeats of the same call.

    Deduplication relies on ``lru_cache``: a call is skipped while an identical set of
    arguments is among the 128 most recently cached calls. All arguments must be hashable.
    """
    logger.log(level, msg, *args, **kwargs)


@lru_cache(maxsize=128)
def warn_once(msg: str, *args, **kwargs):
    """
    Issue a warning via ``warnings.warn()``, skipping repeats of the same call.

    Deduplication relies on ``lru_cache``: a call is skipped while an identical set of
    arguments is among the 128 most recently cached calls. All arguments must be hashable.
    """
    warnings.warn(msg, *args, **kwargs)
def info_value_of_dtype(dtype: torch.dtype):
    """
    Look up the ``torch.finfo`` (floating point dtypes) or ``torch.iinfo`` (all other
    dtypes) object for a PyTorch data type.

    :raises TypeError: If ``dtype`` is ``torch.bool``, which is not supported.
    """
    if dtype == torch.bool:
        raise TypeError("Does not support torch.bool")

    info_fn = torch.finfo if dtype.is_floating_point else torch.iinfo
    return info_fn(dtype)
def min_value_of_dtype(dtype: torch.dtype):
    """
    Return the smallest value representable by a PyTorch data type, taken from the ``min``
    attribute of the object returned by ``info_value_of_dtype()``.
    Does not allow torch.bool.
    """
    dtype_info = info_value_of_dtype(dtype)
    return dtype_info.min
_CUDA_STREAMS: Dict[str, torch.cuda.Stream] = {} def get_or_init_stream(id: str, priority: int = 0) -> torch.cuda.Stream: # NOTE: pri_idx = std::clamp(-priority, 0, max_stream_priorities - 1) # pytorch/c10/cuda/CUDAStream.cpp # it turns out priority>=0 is the same as default priority=0 global _CUDA_STREAMS if id in _CUDA_STREAMS: return _CUDA_STREAMS[id] else: stream = cast(torch.cuda.Stream, torch.cuda.Stream(priority=priority)) _CUDA_STREAMS[id] = stream return stream def record_flops( model: torch.nn.Module, inp: Union[torch.Tensor, tuple], with_backward: bool = False, display: bool = False, ) -> int: # Source: https://alessiodevoto.github.io/Compute-Flops-with-Pytorch-built-in-flops-counter/ istrain = model.training model.eval() inp = inp if isinstance(inp, torch.Tensor) else torch.randn(inp) # FlopCounterMode has some limitations, for example if activation recomputation is used, # it will count the flops for the recomputation as well. This applies to SDPA kernels as well, # since they use recomputation internally (a la flash attention). Additionally, if custom kernels # are used, it will not count the flops for them unless a custom flop formula is registered. # https://github.com/pytorch/pytorch/issues/123800 flop_counter = FlopCounterMode(display=display, depth=999999) with flop_counter: if with_backward: model(inp).sum().backward() else: model(inp) total_flops = flop_counter.get_total_flops() if istrain: model.train() return total_flops