import io
import logging
import os
import pickle
import re
import shutil
import ssl
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple, Type, Union
try:
from functools import cache
except ImportError:
from functools import lru_cache as cache
import requests
import torch
from cached_path import cached_path
from cached_path.schemes import S3Client, SchemeClient, add_scheme_client
from requests.adapters import HTTPAdapter
from rich.progress import track
from .aliases import PathOrStr
from .exceptions import (
OLMoEnvironmentError,
OLMoInvalidRangeRequestError,
OLMoNetworkError,
OLMoUploadError,
)
from .fs_cache import maybe_cache
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
############################################
## Unified API for local and remote files ##
############################################
[docs]
def normalize_path(path: PathOrStr) -> str:
    """
    Normalize a path/URL.

    Trailing slashes are removed and a leading ``file://`` scheme is stripped, so that
    ``file://`` URLs are treated as plain local paths. The filesystem root ``"/"`` is
    preserved instead of being collapsed to ``""`` (which would mean the current directory).

    :param path: The path/URL to normalize.
    :returns: The normalized path/URL as a string.
    """
    normalized = str(path)
    # Only strip the scheme when it's actually a prefix; ``str.replace`` would also mangle
    # any "file://" substring that happens to appear later in the path.
    if normalized.startswith("file://"):
        normalized = normalized[len("file://") :]
    stripped = normalized.rstrip("/")
    if not stripped and normalized.startswith("/"):
        # "/" (or "///") must remain the root directory.
        return "/"
    return stripped
[docs]
def join_path(path: PathOrStr, *paths: PathOrStr) -> PathOrStr:
    """
    Join a base path/URL with one or more additional components.

    URLs are joined with ``/`` after normalizing both sides with :func:`normalize_path`.
    Local paths are joined with :class:`pathlib.Path`.

    :param path: The base path/URL.
    :param paths: Components to append, in order.
    :returns: The joined result. If no components are given, ``path`` is returned unchanged.
    """
    joined = path
    for component in paths:
        joined = (
            f"{normalize_path(joined)}/{normalize_path(component)}"
            if is_url(joined)
            else Path(joined) / component
        )
    return joined
[docs]
def get_parent(path: PathOrStr) -> PathOrStr:
    """
    Get the parent directory of a local path or URL.

    For URLs the last ``/``-separated component is dropped. A URL with nothing after the
    bucket/host (e.g. ``s3://bucket``) is returned as-is. For local paths a
    :class:`pathlib.Path` is returned.

    :param path: The path/URL to get the parent of.
    """
    if not is_url(path):
        return Path(normalize_path(path)).parent
    url = str(normalize_path(path))
    if url.count("/") <= 2:
        return url
    return url.rsplit("/", 1)[0]
[docs]
def resource_path(folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None) -> Path:
    """
    Return a local path for ``folder/fname``, downloading it if no local copy exists yet.

    If ``local_cache`` is given and already contains ``fname`` as a regular file, that
    file is used. Otherwise the resource is fetched through ``cached_path``.

    :param folder: Local folder or remote URL containing the file.
    :param fname: Name of the file within ``folder``.
    :param local_cache: Optional local directory to check first.
    """
    if local_cache is not None:
        candidate = Path(local_cache) / fname
        if candidate.is_file():
            log.info(f"Found local cache of {fname} at {candidate}")
            return candidate
    return cached_path(f"{normalize_path(folder)}/{fname}", quiet=True)
[docs]
def is_url(path: PathOrStr) -> bool:
    """
    Check whether ``path`` looks like a URL, i.e. starts with ``scheme://``.

    The check runs on the normalized form, so ``file://`` paths count as local paths.
    Only lowercase alphanumeric schemes are recognized.

    :param path: Path-like object to check.
    """
    candidate = str(normalize_path(path))
    return bool(re.match(r"[a-z0-9]+://.*", candidate))
[docs]
@maybe_cache(condition=is_url)
def get_file_size(path: PathOrStr) -> int:
    """
    Get the size of a local or remote file in bytes.

    .. warning::
        Uses caching if the argument is URL if the filesystem cache is enabled
        (see :func:`olmo_core.fs_cache.maybe_cache`).

    :param path: Path/URL to the file.
    :raises NotImplementedError: If ``path`` is a URL with an unsupported scheme.
    """
    path = normalize_path(path)
    if not is_url(path):
        return os.stat(path).st_size

    from urllib.parse import urlparse

    parsed = urlparse(str(path))
    scheme = parsed.scheme
    key = parsed.path.strip("/")
    if scheme == "gs":
        return _gcs_file_size(parsed.netloc, key)
    if scheme in ("s3", "r2", "weka"):
        return _s3_file_size(scheme, parsed.netloc, key)
    if scheme in ("http", "https"):
        return _http_file_size(str(path))
    if scheme == "file":
        return get_file_size(str(path).replace("file://", "", 1))
    raise NotImplementedError(f"file size not implemented for '{scheme}' files")
[docs]
def get_bytes_range(path: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
    """
    Get a range of bytes from a local or remote file.

    :param path: Path/URL to the file.
    :param bytes_start: Byte offset to start at.
    :param num_bytes: Number of bytes to get.
    :returns: The requested bytes. For local files, fewer than ``num_bytes`` bytes are
        returned if the range runs past the end of the file.
    :raises NotImplementedError: If ``path`` is a URL with an unsupported scheme.
    """
    path = normalize_path(path)
    if is_url(path):
        from urllib.parse import urlparse

        parsed = urlparse(str(path))
        if parsed.scheme == "gs":
            return _gcs_get_bytes_range(
                parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
            )
        elif parsed.scheme in ("s3", "r2", "weka"):
            return _s3_get_bytes_range(
                parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
            )
        elif parsed.scheme in ("http", "https"):
            return _http_get_bytes_range(str(path), bytes_start, num_bytes)
        elif parsed.scheme == "file":
            return get_bytes_range(str(path).replace("file://", "", 1), bytes_start, num_bytes)
        else:
            raise NotImplementedError(
                f"get_bytes_range not implemented for '{parsed.scheme}' files"
            )
    else:
        with open(path, "rb") as f:
            f.seek(bytes_start)
            return f.read(num_bytes)
[docs]
def upload(source: PathOrStr, target: str, save_overwrite: bool = False, quiet: bool = False):
    """
    Upload a local file to a remote location (GCS, S3, R2, or Weka).

    :param source: Path to the local file to upload.
    :param target: Target URL to upload to.
    :param save_overwrite: Overwrite any existing file.
    :param quiet: Suppress the start/finish log messages.
    :raises FileNotFoundError: If ``source`` is not an existing regular file.
    :raises NotImplementedError: If ``target`` uses an unsupported scheme.
    """
    from urllib.parse import urlparse

    source = Path(normalize_path(source))
    # Explicit check rather than ``assert`` so the validation survives ``python -O``.
    if not source.is_file():
        raise FileNotFoundError(f"'{source}' is not a file")

    # Reject unsupported schemes before logging that an upload is starting.
    parsed = urlparse(target)
    if parsed.scheme not in ("gs", "s3", "r2", "weka"):
        raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")

    num_bytes = get_file_size(source)
    if not quiet:
        log.info(f"Uploading {_format_bytes(num_bytes)} from '{source}' to '{target}'...")

    key = parsed.path.strip("/")
    if parsed.scheme == "gs":
        _gcs_upload(source, parsed.netloc, key, save_overwrite=save_overwrite)
    else:
        _s3_upload(
            source,
            parsed.scheme,
            parsed.netloc,
            key,
            save_overwrite=save_overwrite,
        )

    if not quiet:
        log.info(f"Uploaded {_format_bytes(num_bytes)} to '{target}'")
[docs]
def copy_file(
    source: PathOrStr, target: PathOrStr, save_overwrite: bool = False, quiet: bool = False
):
    """
    Copy a file from ``source`` to ``target``.

    The source is first resolved to a local file via ``cached_path`` (downloading it if it's
    remote). A remote ``target`` is handled by :func:`upload`. For a local ``target``, the
    local source is hard-linked when possible, falling back to a copy if linking raises
    ``OSError`` (e.g. across devices).

    :param source: The path/URL to the source file.
    :param target: The path/URL to the target location.
    :param save_overwrite: Overwrite any existing file.
    :param quiet: Passed through to :func:`upload` for remote targets.
    :raises FileNotFoundError: If the ``source`` file doesn't exist.
    :raises FileExistsError: If the ``target`` already exists and ``save_overwrite=False``.
    """
    source = normalize_path(source)
    target = normalize_path(target)
    local_source = cached_path(source, quiet=True)

    if is_url(target):
        upload(local_source, target, save_overwrite=save_overwrite, quiet=quiet)
        return

    destination = Path(target)
    if not file_exists(destination):
        destination.parent.mkdir(exist_ok=True, parents=True)
    elif save_overwrite:
        destination.unlink(missing_ok=True)
    else:
        raise FileExistsError(destination)

    try:
        os.link(local_source, destination, follow_symlinks=True)
    except OSError:
        # Can't hard link across devices, so copy instead.
        shutil.copyfile(local_source, destination, follow_symlinks=True)
[docs]
def copy_dir(
    source: PathOrStr,
    target: PathOrStr,
    save_overwrite: bool = False,
    num_threads: Optional[int] = None,
    quiet: bool = False,
):
    """
    Copy a directory from ``source`` to ``target``.

    Every file under ``source`` (recursively) is copied with :func:`copy_file` to the same
    relative location under ``target``, using a thread pool.

    :param source: The path/URL to the source directory.
    :param target: The path/URL to the target location.
    :param save_overwrite: Overwrite any existing files.
    :param num_threads: The number of threads to use. Defaults to
        ``get_default_thread_count()``.
    :param quiet: Suppress logging and the progress bar.
    :raises FileNotFoundError: If the ``source`` dir doesn't exist.
    :raises FileExistsError: If any source files already exist in the ``target`` and ``save_overwrite=False``.
    """
    source = normalize_path(source)
    target = normalize_path(target)
    if num_threads is None:
        from .utils import get_default_thread_count

        num_threads = get_default_thread_count()

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        if not quiet:
            log.info(f"Collecting source files from '{source}' to copy to '{target}'...")
        futures = []
        for source_path in list_directory(source, recurse=True, include_dirs=False):
            target_path = join_path(target, _relative_source_path(source, source_path))
            futures.append(
                executor.submit(
                    copy_file, source_path, target_path, save_overwrite=save_overwrite, quiet=True
                )
            )
        if not quiet:
            log.info(f"Collected {len(futures)} source files to copy")
        # Drain the results (re-raising the first failure) while showing a progress bar.
        deque(
            track(
                (f.result() for f in as_completed(futures)),
                description="Copying source files...",
                disable=quiet,
                total=len(futures),
            ),
            maxlen=0,
        )


def _relative_source_path(source: str, source_path: str) -> str:
    """Return ``source_path`` relative to the normalized directory ``source``, '/'-separated."""
    if is_url(source):
        if not source_path.startswith(source):
            raise ValueError(f"'{source_path}' is not under '{source}'")
        return source_path[len(source) :].lstrip("/")
    # Local listings come from ``pathlib``, which normalizes prefixes such as "./" or "//",
    # so a raw string-prefix check against ``source`` can fail. Compare as paths instead.
    return Path(source_path).relative_to(Path(source)).as_posix()
[docs]
def dir_is_empty(dir: PathOrStr) -> bool:
    """
    Check if a local or remote directory is empty.

    A directory that doesn't exist (i.e. listing it raises ``FileNotFoundError``) is also
    reported as empty.

    :param dir: Path/URL to the directory.
    """
    entries = list_directory(dir)
    try:
        next(entries)
    except (StopIteration, FileNotFoundError):
        return True
    return False
[docs]
def file_exists(path: PathOrStr) -> bool:
    """
    Check if a local or remote file exists.

    For ``gs``/``s3``/``r2``/``weka`` URLs, existence is probed by fetching the object's
    size. ``http(s)`` URLs are checked with a HEAD request.

    :param path: Path/URL to a file.
    :raises NotImplementedError: If ``path`` is a URL with an unsupported scheme.
    """
    path = normalize_path(path)
    if not is_url(path):
        return Path(path).exists()

    from urllib.parse import urlparse

    parsed = urlparse(str(path))
    scheme = parsed.scheme
    if scheme in ("http", "https"):
        return _http_file_exists(str(path))
    if scheme not in ("gs", "s3", "r2", "weka"):
        raise NotImplementedError(f"file_exists not implemented for '{parsed.scheme}' files")

    key = parsed.path.strip("/")
    try:
        if scheme == "gs":
            _gcs_file_size(parsed.netloc, key)
        else:
            _s3_file_size(scheme, parsed.netloc, key)
    except FileNotFoundError:
        return False
    return True
[docs]
def remove_file(path: PathOrStr):
    """
    Remove a local or remote file.

    :param path: The path or URL to the file.
    :raises FileNotFoundError: If the file doesn't exist.
    :raises NotImplementedError: If ``path`` is a URL with an unsupported scheme.
    """
    path = normalize_path(path)
    if not is_url(path):
        Path(path).unlink()
        return

    from urllib.parse import urlparse

    parsed = urlparse(str(path))
    bucket, key = parsed.netloc, parsed.path.strip("/")
    if parsed.scheme == "gs":
        return _gcs_remove_file(bucket, key)
    if parsed.scheme in ("s3", "r2", "weka"):
        return _s3_remove_file(parsed.scheme, bucket, key)
    raise NotImplementedError(f"remove_file not implemented for '{parsed.scheme}' files")
[docs]
def clear_directory(dir: PathOrStr, force: bool = False):
    """
    Clear out the contents of a local or remote directory.

    .. warning::
        This function is potentially very destructive!
        By default, for safety, this raises a :class:`ValueError` if you attempt to clear a remote
        directory too close to the root of the bucket. Set ``force=True`` to override.

    A local directory is removed with ``shutil.rmtree(..., ignore_errors=True)``. A local
    path that isn't a directory is left alone.

    :param dir: Path/URL to the directory.
    :param force: See note about safety.
    :raises NotImplementedError: If ``dir`` is a URL with an unsupported scheme.
    """
    dir = normalize_path(dir)
    if not is_url(dir):
        if Path(dir).is_dir():
            shutil.rmtree(dir, ignore_errors=True)
        return

    from urllib.parse import urlparse

    parsed = urlparse(str(dir))
    prefix = parsed.path.strip("/")
    # For safety (so people don't accidentally delete a whole bunch of important data), require
    # the prefix to contain at least two '/' separators (e.g. "a/b/c") unless forced.
    too_shallow = prefix.count("/") < 2
    if parsed.scheme in ("s3", "r2", "weka", "gs") and not force and too_shallow:
        raise ValueError(
            "For safety, clearing a remote directory this close to the root of a bucket is "
            "not allowed by default. To override this behavior set ``force=True``."
        )

    if parsed.scheme == "gs":
        return _gcs_clear_directory(parsed.netloc, prefix)
    if parsed.scheme in ("s3", "r2", "weka"):
        return _s3_clear_directory(parsed.scheme, parsed.netloc, prefix)
    raise NotImplementedError(f"clear_directory not implemented for '{parsed.scheme}' folders")
[docs]
def list_directory(
    dir: PathOrStr, recurse: bool = False, include_files: bool = True, include_dirs: bool = True
) -> Generator[str, None, None]:
    """
    List the contents of a local or remote directory.

    If ``recurse=False``, only the immediate children of the directory are returned, otherwise
    every sub-folder is recursed into.

    :param dir: Path/URL to the directory.
    :param recurse: Whether to recurse into sub-folders.
    :param include_files: Include regular files in the results.
    :param include_dirs: Include directories in the results.

    :returns: A generator over paths in the directory. If the ``dir`` is a URL, the results will be
        full URLs. If the ``dir`` is a local path, the results will be of the form ``join_path(dir, p)``.

    :raises FileNotFoundError: If a local ``dir`` doesn't exist. Because this is a generator,
        errors surface on first iteration rather than on the call itself.
    :raises NotImplementedError: If ``dir`` is a URL with an unsupported scheme.
    """
    dir = normalize_path(dir)
    if not is_url(dir):
        for p in Path(dir).iterdir():
            # Stat once for the directory check; it's needed for both filtering and recursion.
            p_is_dir = p.is_dir()
            if (include_files and p.is_file()) or (include_dirs and p_is_dir):
                yield str(p)
            if recurse and p_is_dir:
                yield from list_directory(
                    p, recurse=True, include_files=include_files, include_dirs=include_dirs
                )
        return

    from urllib.parse import urlparse

    parsed = urlparse(dir)
    if parsed.scheme == "gs":
        yield from _gcs_list_directory(
            parsed.netloc,
            parsed.path.strip("/"),
            recurse=recurse,
            include_files=include_files,
            include_dirs=include_dirs,
        )
    elif parsed.scheme in ("s3", "r2", "weka"):
        yield from _s3_list_directory(
            parsed.scheme,
            parsed.netloc,
            parsed.path.strip("/"),
            recurse=recurse,
            include_files=include_files,
            include_dirs=include_dirs,
        )
    else:
        raise NotImplementedError(f"list_directory not implemented for '{parsed.scheme}' URLs")
[docs]
def glob_directory(pattern: str) -> Generator[str, None, None]:
    """
    Similar to ``glob.glob()`` from the standard library, but works with remote directories as well.

    .. warning::
        Only a subset of glob patterns are supported. Specifically, ``*`` and ``**`` wildcards,
        which follow the semantics defined here https://docs.python.org/3/library/pathlib.html#pattern-language.

    :param pattern: The glob pattern, e.g. ``"s3://bucket/data/**/*.npy"``.
    :returns: A generator over matching paths/URLs, in listing order.
    """
    # Pull out the base directory from the pattern: the '/'-separated components that come
    # before the first component containing a wildcard.
    parts = pattern.split("/")
    first_wild = next((i for i, part in enumerate(parts) if "*" in part), len(parts))
    dir = "/".join(parts[:first_wild]) if first_wild > 0 else "."

    # Matches can live in sub-folders of ``dir`` when the pattern has "**" or a wildcard in any
    # component other than the last one (e.g. "data/*/x.py"), so the listing must recurse.
    recurse = "**" in pattern or first_wild < len(parts) - 1

    # Translate the glob pattern into a regex, e.g.
    # "src/examples/**/*.py" --> "^src/examples/(?:.*/)?[^/]*\.py$".
    # "**/" matches zero or more whole directories, a bare "**" matches anything, and "*"
    # matches anything except '/'.
    pattern_regex = re.compile(
        "^"
        + re.escape(pattern)
        .replace(r"\*\*/", "(?:.*/)?")
        .replace(r"\*\*", ".*")
        .replace(r"\*", "[^/]*")
        + "$"
    )
    for path in list_directory(dir, recurse=recurse):
        if pattern_regex.match(path):
            yield path
[docs]
@maybe_cache(condition=is_url)
def deterministic_glob_directory(pattern: str) -> List[str]:
    """
    Like :func:`glob_directory`, but returns all matches as a sorted list so the ordering
    doesn't depend on how the underlying filesystem/object store lists entries.

    .. warning::
        Uses caching if the argument is URL if the filesystem cache is enabled
        (see :func:`olmo_core.fs_cache.maybe_cache`).
    """
    matches = list(glob_directory(pattern))
    matches.sort()
    return matches
[docs]
def init_client(remote_path: str):
    """
    Eagerly create (and cache) the right client for the given remote resource.

    Creating the client up front, before any threads use it, helps avoid threading issues
    with boto3. Only ``s3://``, ``r2://`` and ``weka://`` URLs need a client here. Anything
    else is a no-op.

    :param remote_path: The path/URL that will be accessed later.
    """
    for scheme in ("s3", "r2", "weka"):
        if remote_path.startswith(f"{scheme}://"):
            _get_s3_client(scheme)
            break
###################################
## Serialization/deserialization ##
###################################
[docs]
def serialize_to_tensor(x: Any) -> torch.Tensor:
    """
    Serialize an object into a 1-D ``torch.uint8`` tensor holding its pickled bytes.

    :param x: The picklable object to serialize.
    :returns: A ``uint8`` tensor that shares memory with a fresh ``bytearray`` of the pickle.
    """
    # A writable ``bytearray`` is used, presumably to avoid torch's warning about creating a
    # tensor from a non-writable buffer.
    buffer = bytearray(pickle.dumps(x))
    return torch.frombuffer(buffer, dtype=torch.uint8)
[docs]
def deserialize_from_tensor(data: torch.Tensor) -> Any:
    """
    Deserialize an object from a byte tensor, e.g. one produced by ``serialize_to_tensor``.

    .. warning::
        This calls :func:`pickle.loads`, which can execute arbitrary code. Only use it on
        tensors that come from a trusted source.

    :param data: A ``torch.uint8`` tensor of pickled bytes. Any shape is accepted; bytes are
        read in flattened order.
    :returns: The unpickled object.
    """
    assert data.dtype == torch.uint8, f"expected a torch.uint8 tensor, got {data.dtype}"
    # Convert the whole tensor in one ``tolist()`` call instead of a Python-level ``.item()``
    # per byte, which is extremely slow for large payloads.
    return pickle.loads(bytes(data.flatten().tolist()))
######################
## Internal helpers ##
######################
def _wait_before_retry(attempt: int):
time.sleep(min(0.5 * 2**attempt, 3.0))
def _format_bytes(num: Union[int, float], suffix="B") -> str:
for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
if abs(num) < 1024.0:
return f"{num:3.1f}{unit}{suffix}"
num /= 1024.0
return f"{num:.1f}Yi{suffix}"
def retriable(
    max_attempts: int = 3,
    retriable_errors: Tuple[Type[Exception], ...] = (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.ChunkedEncodingError,
    ),
    retry_condition: Optional[Callable[[Exception], bool]] = None,
):
    """
    Build a decorator that retries the wrapped function on transient errors.

    An exception is retried if it's an instance of ``retriable_errors`` or if
    ``retry_condition`` (when given) returns a truthy value for it. Any other exception is
    re-raised immediately. Between attempts, ``_wait_before_retry`` is called with the
    1-based attempt number. When the final attempt fails with a retriable error,
    :class:`OLMoNetworkError` is raised, chained to that error.

    :param max_attempts: Total number of attempts, including the first one.
    :param retriable_errors: Exception types that are always retried.
    :param retry_condition: Optional extra predicate deciding whether an exception is retriable.
    """

    def _should_retry(exc: Exception) -> bool:
        if isinstance(exc, retriable_errors):
            return True
        return retry_condition is not None and retry_condition(exc)

    def decorator(func):
        @wraps(func)
        def new_func(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    if not _should_retry(exc):
                        raise
                    if attempt < max_attempts:
                        log.warning(
                            f"'{func.__name__}' failed attempt {attempt} with retriable error: {exc}"
                        )
                        _wait_before_retry(attempt)
                        continue
                    # When torch's DataLoader intercepts exceptions, it may try to re-raise them
                    # by calling their constructor with a single message arg. Torch handles the
                    # absence of a single-parameter constructor, but not other failures in calling
                    # such a constructor, which can surface an irrelevant exception (e.g.
                    # KeyError: 'error') and lose the true exception info. Converting to a type
                    # with a single-parameter constructor avoids that.
                    raise OLMoNetworkError(
                        f"'{func.__name__}' failed {max_attempts} attempts with: {exc}"
                    ) from exc

        return new_func

    return decorator
######################
## HTTPS IO helpers ##
######################
@cache
def _get_http_session() -> requests.Session:
    """
    Return a shared HTTP session with connection pooling.

    The result is cached, so every caller reuses the same session and its pooled
    connections instead of opening new ones. This prevents resource exhaustion when making
    many HTTP requests.
    """
    pooled_adapter = HTTPAdapter(pool_maxsize=50)
    session = requests.Session()
    for url_prefix in ("http://", "https://"):
        session.mount(url_prefix, pooled_adapter)
    return session
@retriable()
def _http_file_size(url: str) -> int:
    """
    Get the size in bytes of an HTTP(S) resource from the ``Content-Length`` of a HEAD request.

    :raises FileNotFoundError: If the server responds with 404.
    :raises requests.exceptions.HTTPError: For other 4xx/5xx responses.
    :raises OLMoNetworkError: If a successful response has no ``Content-Length`` header.
    """
    session = _get_http_session()
    response = session.head(url, allow_redirects=True)
    # Check the status first: an error page can carry its own Content-Length, which would
    # otherwise be reported as the size of the requested file. This matches the sibling
    # HTTP helpers, which also map 404 to FileNotFoundError.
    if response.status_code == 404:
        raise FileNotFoundError(url)
    response.raise_for_status()
    content_length = response.headers.get("content-length")
    if content_length is None:
        raise OLMoNetworkError(
            f"No content-length header found for {url}. "
            f"This can happen when the server is rate-limiting requests or when DDoS protection flags this request. "
            f"Headers: {dict(response.headers)}"
        )
    return int(content_length)
@retriable(
    retry_condition=lambda exc: (
        isinstance(exc, requests.exceptions.HTTPError)
        and exc.response is not None
        and exc.response.status_code == 502
    ),
)
def _http_get_bytes_range(url: str, bytes_start: int, num_bytes: int) -> bytes:
    """
    Fetch ``num_bytes`` bytes starting at ``bytes_start`` from an HTTP(S) URL via a Range request.

    HTTP 502 responses are retried (see the decorator).

    :raises FileNotFoundError: If the server responds with 404.
    :raises OLMoInvalidRangeRequestError: If the server responds with 416 (range not satisfiable).
    :raises requests.exceptions.HTTPError: For other error responses.
    :raises OLMoNetworkError: If the response body isn't exactly ``num_bytes`` long.
    """
    session = _get_http_session()
    # HTTP byte ranges are inclusive on both ends.
    response = session.get(
        url, headers={"Range": f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"}
    )
    if response.status_code == 404:
        raise FileNotFoundError(url)
    if response.status_code == 416:
        # Same error type the S3 helper raises for an invalid range.
        raise OLMoInvalidRangeRequestError(
            f"Invalid range request to '{url}' ({bytes_start=}, {num_bytes=})"
        )
    response.raise_for_status()
    result = response.content
    # Some web servers silently ignore range requests and send everything. Check explicitly
    # (not with ``assert``, which is stripped under ``python -O``) so wrong bytes are never returned.
    if len(result) != num_bytes:
        raise OLMoNetworkError(f"expected {num_bytes} bytes from '{url}', got {len(result)}")
    return result
@retriable()
def _http_file_exists(url: str) -> bool:
    """
    Check whether an HTTP(S) resource exists using a HEAD request.

    Redirects are followed explicitly, because ``requests`` doesn't follow them for HEAD by
    default. Otherwise a 3xx pointing at a missing resource would be reported as existing.
    ``_http_file_size`` also follows redirects.

    :returns: ``False`` on 404, ``True`` on any other successful response.
    :raises requests.exceptions.HTTPError: For other error responses.
    """
    session = _get_http_session()
    response = session.head(url, allow_redirects=True)
    if response.status_code == 404:
        return False
    response.raise_for_status()
    return True
####################
## GCS IO helpers ##
####################
@cache
def _get_gcs_client():
    """
    Return a cached Google Cloud Storage client.

    If default credentials can't be found (``DefaultCredentialsError``), an anonymous client
    is returned instead.
    """
    import google.auth
    import google.auth.exceptions
    from google.cloud import storage as gcs

    try:
        google.auth.default()
        client = gcs.Client()
    except google.auth.exceptions.DefaultCredentialsError:
        client = gcs.Client.create_anonymous_client()
    return client
def _gcs_is_retriable(exc: Exception) -> bool:
    """
    Decide whether a GCS-related exception should be retried.

    Accepts anything ``google.api_core``'s ``if_transient_error`` accepts, plus a few extra
    types observed to be transient.
    """
    from google.api_core.exceptions import BadRequest, GatewayTimeout
    from google.api_core.retry import if_transient_error
    from google.auth.exceptions import RefreshError

    if if_transient_error(exc):
        return True
    extra_transient = (
        requests.exceptions.Timeout,
        BadRequest,  # Weird choice, but Google throws this transiently
        GatewayTimeout,
        RefreshError,
    )
    return isinstance(exc, extra_transient)
def _get_gcs_retry():
    """
    Build the ``google.api_core`` retry policy used for GCS calls in this module.

    Backoff starts at 1s and doubles up to 10s between attempts, with a 600s overall timeout.
    """
    from google.api_core.retry import Retry

    backoff = {"initial": 1.0, "maximum": 10.0, "multiplier": 2.0, "timeout": 600.0}
    # NOTE: it appears google might ignore this predicate
    return Retry(predicate=_gcs_is_retriable, **backoff)
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_file_size(bucket_name: str, key: str) -> int:
    """
    Return the size in bytes of ``gs://{bucket_name}/{key}`` by reloading the blob's metadata.

    :raises FileNotFoundError: If the blob doesn't exist.
    """
    from google.api_core.exceptions import NotFound

    target_blob = _get_gcs_client().bucket(bucket_name).blob(key)
    try:
        target_blob.reload(retry=_get_gcs_retry())
    except NotFound:
        raise FileNotFoundError(f"gs://{bucket_name}/{key}")
    size = target_blob.size
    assert size is not None
    return size
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_remove_file(bucket_name: str, key: str):
    """
    Delete ``gs://{bucket_name}/{key}``.

    :raises FileNotFoundError: If the blob doesn't exist (detected by reloading its metadata
        before deleting).
    """
    from google.api_core.exceptions import NotFound

    gcs_bucket = _get_gcs_client().bucket(bucket_name)
    target_blob = gcs_bucket.blob(key)
    try:
        target_blob.reload(retry=_get_gcs_retry())
        gcs_bucket.delete_blob(target_blob.name)
    except NotFound:
        raise FileNotFoundError(f"gs://{bucket_name}/{key}")
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
    """
    Download ``num_bytes`` bytes starting at ``bytes_start`` from ``gs://{bucket_name}/{key}``.

    Checksum validation is disabled because only a partial range is downloaded.

    :raises FileNotFoundError: If the blob doesn't exist.
    """
    from google.api_core.exceptions import NotFound

    target_blob = _get_gcs_client().bucket(bucket_name).blob(key)
    try:
        target_blob.reload(retry=_get_gcs_retry())
    except NotFound:
        raise FileNotFoundError(f"gs://{bucket_name}/{key}")
    last_byte = bytes_start + num_bytes - 1  # ``end`` is inclusive
    return target_blob.download_as_bytes(
        start=bytes_start,
        end=last_byte,
        retry=_get_gcs_retry(),
        checksum=None,  # type: ignore
    )
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
    """
    Upload the local file ``source`` to ``gs://{bucket_name}/{key}``.

    :raises FileExistsError: If the blob already exists and ``save_overwrite`` is ``False``.
    :raises OLMoUploadError: If the upload itself fails (the original error is chained).
    """
    target_blob = _get_gcs_client().bucket(bucket_name).blob(key)
    already_exists = target_blob.exists(retry=_get_gcs_retry())
    if already_exists and not save_overwrite:
        raise FileExistsError(
            f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
        )

    try:
        target_blob.upload_from_filename(
            source,
            # NOTE: mypy and language servers may complain about the type here, but it does
            # not in fact need to be a ConditionalRetry, a plain old Retry is fine.
            retry=_get_gcs_retry(),  # type: ignore
        )
    except Exception as e:
        raise OLMoUploadError(
            f"Failed to upload '{source}' to '{key}' in GCS bucket '{bucket_name}'"
        ) from e
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_clear_directory(bucket_name: str, prefix: str):
    """
    Delete every blob under ``prefix`` in a GCS bucket.

    A blob that disappears between listing and deletion (e.g. because another process is
    clearing the same directory) is skipped, instead of aborting the rest of the clear.

    :param prefix: Blob name prefix, without the bucket. An empty prefix means the whole bucket.
    """
    from google.api_core.exceptions import NotFound

    storage_client = _get_gcs_client()
    prefix = prefix.strip("/")
    if prefix:
        prefix += "/"
    bucket = storage_client.bucket(bucket_name)
    try:
        blobs = bucket.list_blobs(prefix=prefix, retry=_get_gcs_retry())
        for blob in blobs:
            try:
                bucket.delete_blob(blob.name)
            except NotFound:
                # Already gone; keep deleting the remaining blobs.
                continue
    except NotFound:
        # Listing itself reported NotFound (presumably the bucket doesn't exist), so
        # there's nothing to clear.
        return
def _gcs_list_directory(
    bucket_name: str,
    prefix: str,
    recurse: bool = False,
    include_files: bool = True,
    include_dirs: bool = True,
) -> Generator[str, None, None]:
    """
    List blobs ("files") and sub-prefixes ("directories") directly under ``prefix`` in GCS.

    Results are full ``gs://`` URLs. Directory URLs have no trailing slash.

    :raises FileNotFoundError: If GCS reports ``NotFound`` while listing.
    """
    from google.api_core.exceptions import NotFound

    storage_client = _get_gcs_client()
    prefix = prefix.strip("/")
    if prefix:
        prefix += "/"
    bucket = storage_client.bucket(bucket_name)

    for match_glob in (
        prefix + "*",  # only immediate files
        prefix + "**/",  # only immediate sub-folders
    ):
        blobs = bucket.list_blobs(
            prefix=prefix,
            delimiter="/",
            match_glob=match_glob,
            retry=_get_gcs_retry(),
        )
        # NOTE: ``list_blobs`` is lazy, so ``NotFound`` only surfaces once we iterate. We also
        # need to iterate over these blobs even if not yielding files, otherwise
        # 'blobs.prefixes' won't be populated.
        try:
            for blob in blobs:
                if include_files:
                    yield f"gs://{bucket_name}/{blob.name}"
        except NotFound:
            raise FileNotFoundError(f"gs://{bucket_name}/{prefix}")
        for folder in blobs.prefixes:
            if include_dirs:
                yield f"gs://{bucket_name}/{folder.strip('/')}"
            if recurse:
                yield from _gcs_list_directory(
                    bucket_name,
                    folder,
                    recurse=True,
                    include_files=include_files,
                    include_dirs=include_dirs,
                )
###################
## S3 IO helpers ##
###################
@cache
def _get_s3_client(scheme: str):
    """
    Return a cached boto3 S3 client for ``scheme`` ("s3", "r2" or "weka").

    The profile and endpoint come from the scheme-specific environment variables. If the
    session has no credentials, requests are sent unsigned (anonymous). SSL is disabled
    when ``OLMO_NO_SSL`` is set to a non-zero integer.
    """
    import boto3
    from botocore import UNSIGNED
    from botocore.config import Config

    session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
    has_credentials = session.get_credentials() is not None
    client_config = Config(retries={"max_attempts": 10, "mode": "standard"})
    if not has_credentials:
        client_config = client_config.merge(Config(signature_version=UNSIGNED))
    endpoint_url = _get_s3_endpoint_url(scheme)
    ssl_disabled = bool(int(os.environ.get("OLMO_NO_SSL", "0")))
    return session.client(
        "s3",
        endpoint_url=endpoint_url,
        config=client_config,
        use_ssl=not ssl_disabled,
    )
def _get_s3_profile_name(scheme: str) -> Optional[str]:
if scheme == "s3":
# For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
return os.environ.get("S3_PROFILE")
if scheme == "r2":
profile_name = os.environ.get("R2_PROFILE")
if profile_name is None:
raise OLMoEnvironmentError(
"R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
)
return profile_name
if scheme == "weka":
profile_name = os.environ.get("WEKA_PROFILE")
if profile_name is None:
raise OLMoEnvironmentError(
"WEKA profile name is not set. Did you forget to set the 'WEKA_PROFILE' env var?"
)
return profile_name
raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")
def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
if scheme == "s3":
return None
if scheme == "r2":
r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
if r2_endpoint_url is None:
raise OLMoEnvironmentError(
"R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
)
return r2_endpoint_url
if scheme == "weka":
weka_endpoint_url = os.environ.get("WEKA_ENDPOINT_URL")
if weka_endpoint_url is None:
raise OLMoEnvironmentError(
"WEKA endpoint url is not set. Did you forget to set the 'WEKA_ENDPOINT_URL' env var?"
)
return weka_endpoint_url
raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")
def _s3_retry_condition(err: Exception) -> bool:
    """
    Decide whether an exception from an S3-compatible call should be retried.

    Note that every botocore ``ClientError`` counts as retriable here, including client-side
    4xx responses. Helpers that need to react to e.g. 404 catch those before this predicate
    sees them.
    """
    import botocore.exceptions as boto_errors
    import urllib3.exceptions as urllib3_errors

    retriable_types = (
        boto_errors.ClientError,
        boto_errors.HTTPClientError,
        boto_errors.ConnectionError,
        boto_errors.SSLError,
        urllib3_errors.SSLError,
        ssl.SSLError,
    )
    return isinstance(err, retriable_types)
@retriable(retry_condition=_s3_retry_condition)
def _s3_file_size(scheme: str, bucket_name: str, key: str) -> int:
    """
    Return the size in bytes of ``{scheme}://{bucket_name}/{key}`` using ``head_object``.

    :raises FileNotFoundError: If the object doesn't exist (HTTP 404).
    """
    from botocore.exceptions import ClientError

    try:
        metadata = _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
        return metadata["ContentLength"]
    except ClientError as e:
        if e.response["ResponseMetadata"]["HTTPStatusCode"] != 404:
            raise
        raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
@retriable(retry_condition=_s3_retry_condition)
def _s3_remove_file(scheme: str, bucket_name: str, key: str):
    """
    Delete ``{scheme}://{bucket_name}/{key}``.

    S3's ``DeleteObject`` reports success even when the key doesn't exist. To be able to raise
    :class:`FileNotFoundError` for missing files, as ``remove_file`` documents, a
    ``HeadObject`` probe is issued first.

    :returns: The ``delete_object`` response.
    :raises FileNotFoundError: If the object doesn't exist (HTTP 404).
    """
    from botocore.exceptions import ClientError

    client = _get_s3_client(scheme)
    try:
        client.head_object(Bucket=bucket_name, Key=key)
        return client.delete_object(Bucket=bucket_name, Key=key)
    except ClientError as e:
        if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
            raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
        else:
            raise
@retriable(retry_condition=_s3_retry_condition)
def _s3_get_bytes_range(
    scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int
) -> bytes:
    """
    Read ``num_bytes`` bytes starting at ``bytes_start`` from ``{scheme}://{bucket_name}/{key}``.

    :raises FileNotFoundError: If the object doesn't exist (HTTP 404).
    :raises OLMoInvalidRangeRequestError: If the store rejects the byte range (``InvalidRange``).
    """
    from botocore.exceptions import ClientError

    # HTTP byte ranges are inclusive on both ends.
    byte_range = f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
    try:
        response = _get_s3_client(scheme).get_object(
            Bucket=bucket_name,
            Key=key,
            Range=byte_range,
        )
        return response["Body"].read()
    except ClientError as e:
        if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
            raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
        elif e.response["Error"]["Code"] == "InvalidRange":
            # Chain the ClientError so the store's response details aren't lost.
            raise OLMoInvalidRangeRequestError(
                f"Invalid range request to '{scheme}://{bucket_name}/{key}' ({bytes_start=}, {num_bytes=})"
            ) from e
        else:
            raise
@retriable(retry_condition=_s3_retry_condition)
def _s3_upload(
    source: Path,
    scheme: str,
    bucket_name: str,
    key: str,
    save_overwrite: bool = False,
):
    """
    Upload the local file ``source`` to ``{scheme}://{bucket_name}/{key}``.

    :raises FileExistsError: If the object already exists and ``save_overwrite`` is ``False``.
    """
    from botocore.exceptions import ClientError

    client = _get_s3_client(scheme)
    if not save_overwrite:
        try:
            client.head_object(Bucket=bucket_name, Key=key)
        except ClientError as e:
            # A 404 means the target doesn't exist yet, which is what we want.
            if e.response["ResponseMetadata"]["HTTPStatusCode"] != 404:
                raise
        else:
            # Use the actual scheme so r2/weka targets aren't reported as "s3://".
            raise FileExistsError(
                f"{scheme}://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
            )
    client.upload_file(str(source), bucket_name, key)
@retriable(retry_condition=_s3_retry_condition)
def _s3_clear_directory(scheme: str, bucket_name: str, prefix: str):
    """
    Delete every object under ``prefix`` in an S3-compatible bucket.

    :param prefix: Key prefix, without the bucket. A trailing ``/`` is added if missing, so only
        keys inside that "directory" are matched.
    """
    from botocore.exceptions import ClientError

    if not prefix.endswith("/"):
        prefix = prefix + "/"

    client = _get_s3_client(scheme)
    try:
        # A single ``list_objects_v2`` call returns at most 1000 keys, so paginate to make
        # sure the whole directory is cleared.
        paginator = client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
            for obj in page.get("Contents", []):
                client.delete_object(Bucket=bucket_name, Key=obj["Key"])
    except ClientError as e:
        if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
            return
        # Don't silently swallow other errors such as access denied.
        raise
def _s3_list_directory(
    scheme: str,
    bucket_name: str,
    prefix: str,
    recurse: bool = False,
    include_files: bool = True,
    include_dirs: bool = True,
) -> Generator[str, None, None]:
    """
    List objects ("files") and common prefixes ("directories") directly under ``prefix``.

    Results are full ``{scheme}://`` URLs. Directory URLs have no trailing slash.

    :param prefix: Key prefix, without the bucket. An empty prefix lists the bucket root.
    """
    client = _get_s3_client(scheme)
    paginator = client.get_paginator("list_objects_v2")
    # Only append the delimiter to a non-empty prefix, as the GCS helper does. Turning ""
    # into "/" would only match keys that literally start with "/", i.e. usually nothing.
    if prefix and not prefix.endswith("/"):
        prefix = prefix + "/"
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix, MaxKeys=50, Delimiter="/"):
        if include_files:
            for file_item in page.get("Contents", []):
                yield f"{scheme}://{bucket_name}/{file_item['Key']}"
        for dir_item in page.get("CommonPrefixes", []):
            if include_dirs:
                yield f"{scheme}://{bucket_name}/{dir_item['Prefix'].strip('/')}"
            if recurse:
                yield from _s3_list_directory(
                    scheme,
                    bucket_name,
                    dir_item["Prefix"],
                    recurse=True,
                    include_files=include_files,
                    include_dirs=include_dirs,
                )
#############################################
## Custom cached path client for 'weka://' ##
#############################################
[docs]
def add_cached_path_clients():
    """
    Register this module's extra ``cached_path`` scheme clients.

    Currently this registers the client that handles ``weka://`` URLs.
    """
    for client_cls in (_WekaClient,):
        add_scheme_client(client_cls)
class _WekaClient(SchemeClient):
    """
    ``cached_path`` scheme client for ``weka://bucket/key`` URLs.

    Weka is accessed through the S3 API using the client from ``_get_s3_client("weka")``.
    Object metadata (``head_object``) is fetched lazily on first use and then reused.
    """

    recoverable_errors = S3Client.recoverable_errors

    scheme = "weka"

    def __init__(self, resource: str) -> None:
        super().__init__(resource)
        self.bucket_name, self.path = _WekaClient._split_cloud_path(resource, "weka")
        self.s3 = _get_s3_client("weka")
        # Cached ``head_object`` response; populated by ``_ensure_object_info``.
        self.object_info = None

    @staticmethod
    def _split_cloud_path(url: str, provider: str) -> Tuple[str, str]:
        """Split a full s3 path into the bucket name and path."""
        from urllib.parse import urlparse

        parsed = urlparse(url)
        if not (parsed.netloc and parsed.path):
            raise ValueError("bad {} path {}".format(provider, url))
        # Drop the single leading '/' that urlparse leaves on the path.
        key = parsed.path[1:] if parsed.path.startswith("/") else parsed.path
        return parsed.netloc, key

    def _ensure_object_info(self):
        """Fetch and cache the object's metadata; raise ``FileNotFoundError`` on a 404."""
        import botocore.exceptions as boto_exceptions

        if self.object_info is not None:
            return
        try:
            self.object_info = self.s3.head_object(Bucket=self.bucket_name, Key=self.path)
        except boto_exceptions.ClientError as e:
            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                raise FileNotFoundError(f"weka://{self.bucket_name}/{self.path}") from e
            raise e

    def get_etag(self) -> Optional[str]:
        """Return the object's ETag from its cached metadata, if present."""
        self._ensure_object_info()
        assert self.object_info is not None
        return self.object_info.get("ETag")

    def get_size(self) -> Optional[int]:
        """Return the object's size in bytes from its cached metadata, if present."""
        self._ensure_object_info()
        assert self.object_info is not None
        return self.object_info.get("ContentLength")

    def get_resource(self, temp_file: io.BufferedWriter) -> None:
        """Download the whole object into ``temp_file``."""
        self.s3.download_fileobj(Fileobj=temp_file, Bucket=self.bucket_name, Key=self.path)

    def get_bytes_range(self, index: int, length: int) -> bytes:
        """Return ``length`` bytes starting at ``index`` (HTTP byte ranges are inclusive)."""
        byte_range = f"bytes={index}-{index + length - 1}"
        response = self.s3.get_object(Bucket=self.bucket_name, Key=self.path, Range=byte_range)
        return response["Body"].read()