Source code for olmo_core.io

import logging
import os
import pickle
import re
import shutil
import time
from os import PathLike
from pathlib import Path
from typing import Any, Optional, Union

try:
    from functools import cache
except ImportError:
    from functools import lru_cache as cache

import torch

from .exceptions import OLMoEnvironmentError, OLMoNetworkError

log = logging.getLogger(__name__)

PathOrStr = Union[Path, PathLike, str]

############################################
## Unified API for local and remote files ##
############################################


def is_url(path: PathOrStr) -> bool:
    """
    Tell whether ``path`` looks like a URL rather than a plain local path.

    A value counts as a URL when its string form starts with a scheme made of
    lowercase letters and/or digits immediately followed by ``://``.

    :param path: Path-like object to check.
    """
    scheme_match = re.match(r"[a-z0-9]+://.*", str(path))
    return scheme_match is not None
def file_size(path: PathOrStr) -> int:
    """
    Return the size, in bytes, of a local or remote file.

    Local paths are measured with :func:`os.stat`. URLs are dispatched on their
    scheme: ``gs``, ``s3``/``r2``/``weka``, ``http``/``https`` and ``file``.

    :param path: Path/URL to the file.

    :raises NotImplementedError: If ``path`` is a URL with any other scheme.
    """
    # Local paths need no URL parsing at all.
    if not is_url(path):
        return os.stat(path).st_size

    from urllib.parse import urlparse

    parsed = urlparse(str(path))
    if parsed.scheme == "gs":
        return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
    if parsed.scheme in ("s3", "r2", "weka"):
        return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
    if parsed.scheme in ("http", "https"):
        return _http_file_size(str(path))
    if parsed.scheme == "file":
        # Drop the scheme and handle the remainder as an ordinary local path.
        return file_size(str(path).replace("file://", "", 1))
    raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
def get_bytes_range(path: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
    """
    Get a range of bytes from a local or remote file.

    Local paths are read directly from disk. URLs are dispatched on their
    scheme: ``gs``, ``s3``/``r2``/``weka``, ``http``/``https`` and ``file``.

    :param path: Path/URL to the file.
    :param bytes_start: Byte offset to start at.
    :param num_bytes: Number of bytes to get.

    :raises NotImplementedError: If ``path`` is a URL with any other scheme.
    """
    if is_url(path):
        from urllib.parse import urlparse

        parsed = urlparse(str(path))
        if parsed.scheme == "gs":
            return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
        elif parsed.scheme in ("s3", "r2", "weka"):
            return _s3_get_bytes_range(
                parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
            )
        elif parsed.scheme in ("http", "https"):
            return _http_get_bytes_range(str(path), bytes_start, num_bytes)
        elif parsed.scheme == "file":
            # Drop the scheme and handle the remainder as an ordinary local path.
            return get_bytes_range(str(path).replace("file://", "", 1), bytes_start, num_bytes)
        else:
            # The message names this function; it previously said "file size",
            # a copy-paste from file_size().
            raise NotImplementedError(f"get_bytes_range not implemented for '{parsed.scheme}' files")
    else:
        with open(path, "rb") as f:
            f.seek(bytes_start)
            return f.read(num_bytes)
def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
    """
    Upload a local file to a remote target location.

    Targets with the ``gs`` scheme go through the GCS helper; targets with the
    ``s3``, ``r2`` or ``weka`` scheme go through the S3 helper.

    :param source: Path to the file to upload.
    :param target: Target URL to upload to.
    :param save_overwrite: Overwrite any existing file.

    :raises NotImplementedError: If ``target`` uses any other scheme.
    """
    from urllib.parse import urlparse

    source = Path(source)
    assert source.is_file()

    num_bytes = file_size(source)
    log.info(f"Uploading {_format_bytes(num_bytes)} from '{source}' to '{target}'...")

    parsed = urlparse(target)
    bucket_name = parsed.netloc
    key = parsed.path.strip("/")

    if parsed.scheme == "gs":
        _gcs_upload(source, bucket_name, key, save_overwrite=save_overwrite)
        return
    if parsed.scheme in ("s3", "r2", "weka"):
        _s3_upload(source, parsed.scheme, bucket_name, key, save_overwrite=save_overwrite)
        return
    raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")
def dir_is_empty(dir: PathOrStr) -> bool:
    """
    Check if a local directory is empty.

    A path that does not exist, or that is not a directory, is also reported
    as empty.

    :param dir: Path to the local directory.
    """
    directory = Path(dir)
    if not directory.is_dir():
        return True
    # Pull at most one entry from the lazy glob; the sentinel means there was none.
    first_entry = next(directory.glob("*"), None)
    return first_entry is None
def file_exists(path: PathOrStr) -> bool:
    """
    Check if a local or remote file exists.

    For ``gs`` and ``s3``/``r2``/``weka`` URLs, existence is determined by
    asking for the object's size and treating ``FileNotFoundError`` as "does
    not exist".

    :param path: Path/URL to a file.

    :raises NotImplementedError: If ``path`` is a URL with an unsupported scheme.
    """
    if not is_url(path):
        return Path(path).exists()

    from urllib.parse import urlparse

    parsed = urlparse(str(path))

    if parsed.scheme in ("http", "https"):
        return _http_file_exists(str(path))

    if parsed.scheme == "file":
        # Drop the scheme and handle the remainder as an ordinary local path.
        return file_exists(str(path).replace("file://", "", 1))

    if parsed.scheme == "gs":
        try:
            _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
        except FileNotFoundError:
            return False
        return True

    if parsed.scheme in ("s3", "r2", "weka"):
        try:
            _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
        except FileNotFoundError:
            return False
        return True

    raise NotImplementedError(f"file_exists not implemented for '{parsed.scheme}' files")
def clear_directory(dir: PathOrStr):
    """
    Clear out a local or remote directory.

    Remote directories are handled for the ``s3``, ``r2`` and ``weka`` schemes,
    and ``file://`` URLs are treated as local paths. A local directory is
    removed with :func:`shutil.rmtree` (errors ignored); a local path that is
    not a directory is left alone.

    :param dir: Path/URL to the directory.

    :raises NotImplementedError: If ``dir`` is a URL with any other scheme
        (this currently includes ``gs://``).
    """
    if not is_url(dir):
        if Path(dir).is_dir():
            shutil.rmtree(dir, ignore_errors=True)
        return None

    from urllib.parse import urlparse

    parsed = urlparse(str(dir))
    if parsed.scheme in ("s3", "r2", "weka"):
        return _s3_clear_directory(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
    if parsed.scheme == "file":
        # Drop the scheme and handle the remainder as an ordinary local path.
        return clear_directory(str(dir).replace("file://", "", 1))
    raise NotImplementedError(f"clear_directory not implemented for '{parsed.scheme}' folders")
###################################
## Serialization/deserialization ##
###################################
def serialize_to_tensor(x: Any) -> torch.Tensor:
    """
    Pickle an object and return the resulting bytes as a ``torch.uint8`` tensor.

    :param x: The picklable object to serialize.
    """
    # torch.frombuffer needs a writable buffer, so copy the immutable
    # pickle output into a bytearray first.
    buffer = bytearray(pickle.dumps(x))
    return torch.frombuffer(buffer, dtype=torch.uint8)
def deserialize_from_tensor(data: torch.Tensor) -> Any:
    """
    Deserialize an object from a byte tensor using pickle.

    The tensor is flattened before decoding, so its shape does not matter as
    long as the elements, read in order, form a valid pickle stream.

    .. warning::
        This calls :func:`pickle.loads`, which can execute arbitrary code.
        Only pass tensors that come from a trusted source.

    :param data: The byte tensor to deserialize. Must have dtype ``torch.uint8``.

    :raises AssertionError: If ``data`` does not have dtype ``torch.uint8``.
    """
    assert data.dtype == torch.uint8
    # Convert the whole tensor in one call rather than calling ``.item()`` on
    # every element, which builds a 0-dim tensor per byte.
    return pickle.loads(bytes(data.flatten().tolist()))
######################
## Internal helpers ##
######################


def _wait_before_retry(attempt: int):
    """
    Sleep before the next retry, doubling the delay with each attempt.

    The delay is ``0.5 * 2**attempt`` seconds, capped at 3 seconds.

    :param attempt: The number of the attempt that just failed.
    """
    time.sleep(min(0.5 * 2**attempt, 3.0))


def _format_bytes(num: Union[int, float], suffix="B") -> str:
    """
    Format a byte count as a human-readable string with binary (1024-based) units.

    :param num: The quantity to format.
    :param suffix: The unit suffix appended after the binary prefix.
    """
    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
        if abs(num) < 1024.0:
            return f"{num:3.1f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:.1f}Yi{suffix}"


######################
## HTTPS IO helpers ##
######################


def _http_file_size(url: str) -> int:
    """
    Get the size of a file served over HTTP(S) from the ``content-length``
    header of a ``HEAD`` response (redirects are followed).

    :param url: The URL of the file.

    :raises FileNotFoundError: If the server responds with a 404.
    """
    import requests

    response = requests.head(url, allow_redirects=True)
    # Check the status before trusting the headers: an error response carries
    # the content-length of the error body, not of the requested file.
    if response.status_code == 404:
        raise FileNotFoundError(url)
    response.raise_for_status()
    content_length = response.headers.get("content-length")
    assert content_length
    return int(content_length)


def _http_get_bytes_range(url: str, bytes_start: int, num_bytes: int) -> bytes:
    """
    Fetch ``num_bytes`` bytes starting at ``bytes_start`` with an HTTP range request.

    :param url: The URL of the file.
    :param bytes_start: Byte offset to start at.
    :param num_bytes: Number of bytes to get.

    :raises FileNotFoundError: If the server responds with a 404.
    """
    import requests

    response = requests.get(url, headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"})
    if response.status_code == 404:
        raise FileNotFoundError(url)
    response.raise_for_status()
    result = response.content
    # Some web servers silently ignore range requests and send everything
    assert len(result) == num_bytes, f"expected {num_bytes} bytes, got {len(result)}"
    return result


def _http_file_exists(url: str) -> bool:
    """
    Check whether a URL exists by issuing a ``HEAD`` request.

    :param url: The URL to check.

    :returns: ``False`` on a 404, ``True`` if the response is otherwise not an error.
    """
    import requests

    response = requests.head(url)
    if response.status_code == 404:
        return False
    response.raise_for_status()
    return True


####################
## GCS IO helpers ##
####################


@cache
def _get_gcs_client():
    """Create the GCS client; the result is cached and reused across calls."""
    from google.cloud import storage as gcs

    return gcs.Client()


def _gcs_file_size(bucket_name: str, key: str) -> int:
    """
    Get the size of a GCS object in bytes.

    :param bucket_name: Name of the bucket.
    :param key: Object name within the bucket.

    :raises FileNotFoundError: If the object does not exist.
    """
    from google.api_core.exceptions import NotFound

    storage_client = _get_gcs_client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(key)
    try:
        blob.reload()
    except NotFound:
        raise FileNotFoundError(f"gs://{bucket_name}/{key}")
    assert blob.size is not None
    return blob.size


def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
    """
    Download ``num_bytes`` bytes of a GCS object starting at ``bytes_start``.

    :param bucket_name: Name of the bucket.
    :param key: Object name within the bucket.
    :param bytes_start: Byte offset to start at.
    :param num_bytes: Number of bytes to get.

    :raises FileNotFoundError: If the object does not exist.
    """
    from google.api_core.exceptions import NotFound

    storage_client = _get_gcs_client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(key)
    try:
        blob.reload()
    except NotFound:
        raise FileNotFoundError(f"gs://{bucket_name}/{key}")
    return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1)


def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
    """
    Upload a local file to GCS.

    :param source: Path to the local file.
    :param bucket_name: Name of the target bucket.
    :param key: Object name within the bucket.
    :param save_overwrite: Overwrite the object if it already exists.

    :raises FileExistsError: If the object exists and ``save_overwrite`` is false.
    """
    storage_client = _get_gcs_client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(key)
    if not save_overwrite and blob.exists():
        raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
    blob.upload_from_filename(source)


###################
## S3 IO helpers ##
###################


@cache
def _get_s3_client(scheme: str):
    """
    Create the boto3 S3 client for a URL scheme; one client is cached per scheme.

    SSL is used unless the ``OLMO_NO_SSL`` environment variable is set to a
    non-zero integer.

    :param scheme: The URL scheme (``s3``, ``r2`` or ``weka``).
    """
    import boto3
    from botocore.config import Config

    session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
    return session.client(
        "s3",
        endpoint_url=_get_s3_endpoint_url(scheme),
        config=Config(retries={"max_attempts": 10, "mode": "standard"}),
        use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
    )


def _get_s3_profile_name(scheme: str) -> Optional[str]:
    """
    Look up the profile name for a scheme from the environment.

    :param scheme: The URL scheme (``s3``, ``r2`` or ``weka``).

    :raises OLMoEnvironmentError: If the variable required for ``r2`` or ``weka`` is not set.
    :raises NotImplementedError: If the scheme is not recognized.
    """
    if scheme == "s3":
        # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
        return os.environ.get("S3_PROFILE")
    if scheme == "r2":
        profile_name = os.environ.get("R2_PROFILE")
        if profile_name is None:
            raise OLMoEnvironmentError(
                "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
            )

        return profile_name
    if scheme == "weka":
        profile_name = os.environ.get("WEKA_PROFILE")
        if profile_name is None:
            raise OLMoEnvironmentError(
                "WEKA profile name is not set. Did you forget to set the 'WEKA_PROFILE' env var?"
            )

        return profile_name

    raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")


def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
    """
    Look up the endpoint URL for a scheme from the environment.

    :param scheme: The URL scheme (``s3``, ``r2`` or ``weka``).

    :returns: ``None`` for ``s3``; otherwise the value of the scheme's
        ``*_ENDPOINT_URL`` environment variable.

    :raises OLMoEnvironmentError: If the variable required for ``r2`` or ``weka`` is not set.
    :raises NotImplementedError: If the scheme is not recognized.
    """
    if scheme == "s3":
        return None
    if scheme == "r2":
        r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
        if r2_endpoint_url is None:
            raise OLMoEnvironmentError(
                "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
            )

        return r2_endpoint_url
    if scheme == "weka":
        weka_endpoint_url = os.environ.get("WEKA_ENDPOINT_URL")
        if weka_endpoint_url is None:
            raise OLMoEnvironmentError(
                "WEKA endpoint url is not set. Did you forget to set the 'WEKA_ENDPOINT_URL' env var?"
            )

        return weka_endpoint_url

    raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")


def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
    """
    Get the size of an S3 object in bytes, retrying on client errors other than 404.

    :param scheme: The URL scheme (``s3``, ``r2`` or ``weka``).
    :param bucket_name: Name of the bucket.
    :param key: Object key within the bucket.
    :param max_attempts: Maximum number of attempts before giving up.

    :raises FileNotFoundError: If the object does not exist.
    :raises OLMoNetworkError: If every attempt fails.
    """
    from botocore.exceptions import ClientError

    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
        except ClientError as e:
            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
            err = e

        if attempt < max_attempts:
            log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
            _wait_before_retry(attempt)

    raise OLMoNetworkError("Failed to get s3 file size") from err


def _s3_get_bytes_range(
    scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
) -> bytes:
    """
    Read ``num_bytes`` bytes of an S3 object starting at ``bytes_start``, with retries.

    :param scheme: The URL scheme (``s3``, ``r2`` or ``weka``).
    :param bucket_name: Name of the bucket.
    :param key: Object key within the bucket.
    :param bytes_start: Byte offset to start at.
    :param num_bytes: Number of bytes to get.
    :param max_attempts: Maximum number of attempts before giving up.

    :raises FileNotFoundError: If the object does not exist.
    :raises OLMoNetworkError: If every attempt fails.
    """
    from botocore.exceptions import ClientError, ConnectionError, HTTPClientError

    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return (
                _get_s3_client(scheme)
                .get_object(
                    Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
                )["Body"]
                .read()
            )
        except ClientError as e:
            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
            err = e
        except (HTTPClientError, ConnectionError) as e:
            # ResponseStreamingError (subclass of HTTPClientError) can happen as
            # a result of a failed read from the stream (http.client.IncompleteRead).
            # Retrying can help in this case.
            err = e

        if attempt < max_attempts:
            log.warning(
                "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
            )
            _wait_before_retry(attempt)

    # When torch's DataLoader intercepts exceptions, it may try to re-raise them
    # by recalling their constructor with a single message arg. Torch has some
    # logic to deal with the absence of a single-parameter constructor, but it
    # doesn't gracefully handle other possible failures in calling such a constructor
    # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
    # in us losing the true exception info. To avoid this, we change the exception
    # to a type that has a single-parameter constructor.
    raise OLMoNetworkError("Failed to get bytes range from s3") from err


def _s3_upload(
    source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3
):
    """
    Upload a local file to S3.

    Unless ``save_overwrite`` is set, the target key is first checked (with
    retries) and the upload is refused if the object already exists.

    :param source: Path to the local file.
    :param scheme: The URL scheme (``s3``, ``r2`` or ``weka``).
    :param bucket_name: Name of the target bucket.
    :param key: Object key within the bucket.
    :param save_overwrite: Overwrite the object if it already exists.
    :param max_attempts: Maximum number of attempts for the existence check.

    :raises FileExistsError: If the object exists and ``save_overwrite`` is false.
    :raises OLMoNetworkError: If the existence check or the upload fails.
    """
    from botocore.exceptions import ClientError

    err: Optional[Exception] = None
    if not save_overwrite:
        for attempt in range(1, max_attempts + 1):
            try:
                _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
                raise FileExistsError(
                    f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
                )
            except ClientError as e:
                if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                    # The object does not exist, so it is safe to upload.
                    err = None
                    break
                err = e

            if attempt < max_attempts:
                log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
                _wait_before_retry(attempt)

        if err is not None:
            raise OLMoNetworkError("Failed to check object existence during s3 upload") from err

    try:
        _get_s3_client(scheme).upload_file(source, bucket_name, key)
    except ClientError as e:
        raise OLMoNetworkError("Failed to upload to s3") from e


def _s3_clear_directory(scheme: str, bucket_name: str, prefix: str, max_attempts: int = 3):
    """
    Delete every object under an S3 prefix, retrying on client errors other than 404.

    A trailing ``/`` is appended to ``prefix`` when missing. A prefix with no
    objects is a no-op.

    :param scheme: The URL scheme (``s3``, ``r2`` or ``weka``).
    :param bucket_name: Name of the bucket.
    :param prefix: Key prefix ("directory") whose objects are deleted.
    :param max_attempts: Maximum number of attempts before giving up.

    :raises OLMoNetworkError: If every attempt fails.
    """
    from botocore.exceptions import ClientError

    if not prefix.endswith("/"):
        prefix = prefix + "/"

    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            client = _get_s3_client(scheme)
            # A single list_objects_v2 call returns at most one page of keys, so
            # walk every page to make sure the whole prefix is cleared.
            paginator = client.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
                # "Contents" is absent from the response when nothing matches.
                for o in page.get("Contents", []):
                    client.delete_object(Bucket=bucket_name, Key=o["Key"])
            return
        except ClientError as e:
            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                return
            err = e

        if attempt < max_attempts:
            log.warning(
                "%s failed attempt %d with retriable error: %s", _s3_clear_directory.__name__, attempt, err
            )
            _wait_before_retry(attempt)

    raise OLMoNetworkError("Failed to remove S3 directory") from err