import copy
import dataclasses
import json
from dataclasses import dataclass
from enum import Enum
from typing import (
Any,
Callable,
ClassVar,
Collection,
Dict,
Generator,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
import torch
import yaml
from cached_path import cached_path
from dataclass_extensions import Registrable, decode
from typing_extensions import Self
from .aliases import PathOrStr
from .exceptions import OLMoConfigurationError
# Names exported by ``from <this module> import *``.
# ``Registrable`` comes from ``dataclass_extensions`` and is re-exported for convenience.
__all__ = ["Config", "DType", "StrEnum", "UNSET", "Registrable"]
[docs]
class StrEnum(str, Enum):
    """
    A string-valued :class:`~enum.Enum` whose ``str()`` is the member's plain value.

    This is a stand-in for Python's :class:`enum.StrEnum` (added in Python 3.11), provided so
    the same behavior is available on older Python versions.
    """

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        # The repr is the string value wrapped in single quotes, e.g. 'float32'.
        return "'{}'".format(str(self))
# Type variable bound to ``Config``: the alternate constructors on ``Config`` (``from_dict``,
# ``from_file``, ...) annotate ``cls: Type[C]`` and return ``C``, so they are typed as
# returning the subclass they are called on.
C = TypeVar("C", bound="Config")
[docs]
@dataclass
class Config:
    """
    A base class for configuration dataclasses.

    .. important::
        When you subclass this you should still decorate your subclasses with
        :func:`@dataclass <dataclasses.dataclass>`. For example::

            @dataclass
            class MyConfig(Config):
                ...

    .. important::
        Config classes need to be serializable, so you should only use simple types for your fields.
        Though you can use nested configs.
    """

    CLASS_NAME_FIELD = "_CLASS_"
    """
    The name of the class name field injected into the dictionary from :meth:`as_dict()` or
    :meth:`as_config_dict()`.
    """

    _IGNORE_FIELDS: ClassVar[Tuple[str, ...]] = ()
    """
    Fields to ignore when loading from config (for backwards compatibility).
    """

    def as_dict(
        self,
        *,
        exclude_none: bool = False,
        exclude_private_fields: bool = False,
        exclude: Optional[Collection[str]] = None,
        include_class_name: bool = False,
        include_registered_name: bool = False,
        json_safe: bool = False,
        recurse: bool = True,
    ) -> Dict[str, Any]:
        """
        Convert into a regular Python dictionary.

        :param exclude_none: Don't include values that are ``None``.
        :param exclude_private_fields: Don't include private fields (names starting with ``_``).
        :param exclude: A collection of field names to exclude. This is applied to the fields of
            every (nested) dataclass, not only to ``self``.
        :param include_class_name: Include a :data:`CLASS_NAME_FIELD` entry holding
            ``"<module>.<class name>"`` for each dataclass.
        :param include_registered_name: If the config is :class:`Registrable`, include the
            registered name under the key "type".
        :param json_safe: Output only JSON-safe types. Lists, tuples and sets become lists, and
            values of any other non-primitive type are converted with :func:`str`.
        :param recurse: Recurse into fields that are also configs/dataclasses.
        """
        exclude_set = set(exclude) if exclude is not None else set()

        def iter_fields(d) -> Generator[Tuple[str, Any], None, None]:
            for field in dataclasses.fields(d):
                # Non-init fields can't be passed back to the constructor, so never emit them.
                if field.name in exclude_set or not field.init:
                    continue
                value = getattr(d, field.name)
                if exclude_none and value is None:
                    continue
                elif exclude_private_fields and field.name.startswith("_"):
                    continue
                else:
                    yield (field.name, value)

        def convert(d: Any, recurse: bool = True) -> Any:
            # ``is_dataclass`` is also true for dataclass *types*; only instances have field
            # values that can be read with ``getattr``.
            if dataclasses.is_dataclass(d) and not isinstance(d, type):
                if recurse:
                    out = {k: convert(v) for k, v in iter_fields(d)}
                else:
                    out = dict(iter_fields(d))
                if include_class_name:
                    out[self.CLASS_NAME_FIELD] = f"{d.__class__.__module__}.{d.__class__.__name__}"
                if include_registered_name and isinstance(d, Registrable):
                    try:
                        out["type"] = d.get_registered_name()
                    except ValueError:
                        # Presumably raised when the class has no registered name; the "type"
                        # entry is simply omitted in that case.
                        pass
                return out
            elif isinstance(d, dict):
                return {k: convert(v) for k, v in d.items()}
            elif isinstance(d, (list, tuple, set)):
                if json_safe:
                    return [convert(x) for x in d]
                return d.__class__(convert(x) for x in d)
            elif d is None or isinstance(d, (float, int, bool, str)):
                return d
            elif json_safe:
                return str(d)
            else:
                return d

        return convert(self, recurse=recurse)

    def as_config_dict(self) -> Dict[str, Any]:
        """
        A convenience wrapper around :meth:`as_dict()` for creating JSON-safe dictionaries suitable
        for recording the config.
        """
        return self.as_dict(
            exclude_none=True,
            exclude_private_fields=True,
            include_class_name=True,
            include_registered_name=True,
            json_safe=True,
            recurse=True,
        )

    def apply(self, func: Callable[["Config"], None]):
        """
        Recursively apply a function to every config instance field, including ``self``.

        ``func`` is called on a config before its fields are visited. Values inside dicts,
        lists, tuples and sets are visited too.

        :param func: The function to apply.
        """

        def visit(d):
            if isinstance(d, Config):
                func(d)
            # Skip dataclass *types*: reading fields off a class with ``getattr`` either fails
            # or yields class-level defaults rather than instance values.
            if dataclasses.is_dataclass(d) and not isinstance(d, type):
                for field in dataclasses.fields(d):
                    visit(getattr(d, field.name))
            elif isinstance(d, dict):
                for value in d.values():
                    visit(value)
            elif isinstance(d, (list, tuple, set)):
                for x in d:
                    visit(x)

        visit(self)

    def validate(self):
        """
        Validate fields in ``self``. This may modify ``self`` in-place.
        """
        pass

    def merge(self, dotlist: List[str], prefix: Optional[str] = None, strict: bool = True) -> Self:
        """
        Merge self with fields from a "dotlist", creating a new object.

        :param dotlist: A list of field attributes with dot notation, e.g. ``foo.bar=1``.
        :param prefix: Only use override items in the dotlist that start with a given prefix name,
            and strip that prefix (including the subsequent ".") before applying the overrides.
        :param strict: Parse the dotlist strictly. When ``False``, overrides whose first dotted
            component is not a field of ``self`` are dropped instead of being applied.
        """
        overrides = _clean_opts(dotlist)
        if prefix is not None:
            prefix_dot = f"{prefix}."
            overrides = [
                (k[len(prefix_dot) :], v) for k, v in overrides if k.startswith(prefix_dot)
            ]
        if not strict:
            field_names = {f.name for f in dataclasses.fields(self)}
            # Field names are identifiers and never contain ".", so matching the first dotted
            # component is the same as ``k == name or k.startswith(f"{name}.")`` for some field.
            overrides = [(k, v) for k, v in overrides if k.split(".", 1)[0] in field_names]
        # Round-trip through a dict that records class names, so nested configs are decoded
        # back into the same classes by ``from_dict``.
        merged_data = self.as_dict(include_class_name=True, include_registered_name=True)
        for key, value in overrides:
            _set_nested(merged_data, key, value)
        return self.from_dict(merged_data)

    def replace(self, **changes) -> Self:
        """
        Creates a new object of the same type, replacing fields with values from ``changes``.
        """
        return dataclasses.replace(self, **changes)

    def copy(self, deep: bool = True) -> Self:
        """
        Creates a new object of the same type, with the same values.

        :param deep: Make a deep copy (the default) rather than a shallow one.
        """
        return copy.deepcopy(self) if deep else copy.copy(self)

    @classmethod
    def from_dict(cls: Type[C], data: Dict[str, Any], overrides: Optional[List[str]] = None) -> C:
        """
        Initialize from a regular Python dictionary.

        Nested dictionaries carrying :data:`CLASS_NAME_FIELD` (as written by :meth:`as_dict()`
        with ``include_class_name=True``) are decoded into the named class. ``data`` itself is
        not modified, even when ``overrides`` are given.

        :param data: A Python dictionary.
        :param overrides: A list of field overrides with dot notation, e.g. ``foo.bar=1``.
        :raises OLMoConfigurationError: If a nested class named by :data:`CLASS_NAME_FIELD`
            can't be constructed from its fields.
        """
        from importlib import import_module

        def resolve_cls(cls_name: str) -> Optional[Any]:
            # Expects "<module path>.<class name>"; a name without a module part is ignored.
            if "." not in cls_name:
                return None
            module_name, _, attr_name = cls_name.rpartition(".")
            return getattr(import_module(module_name), attr_name, None)

        def has_init_field(c: Any, name: str) -> bool:
            # True if ``c`` is a dataclass that accepts ``name`` as a constructor argument.
            return dataclasses.is_dataclass(c) and any(
                f.name == name and f.init for f in dataclasses.fields(c)
            )

        def copy_containers(d: Any) -> Any:
            # Copy the dict/list skeleton (the only containers ``_set_nested`` walks into) while
            # sharing leaf values, so applying overrides never mutates the caller's ``data``.
            if isinstance(d, dict):
                return {k: copy_containers(v) for k, v in d.items()}
            elif isinstance(d, list):
                return [copy_containers(x) for x in d]
            return d

        def decode_data(d: Any, prefix: str) -> Any:
            # ``prefix`` is the dotted path of ``d`` within ``data``; it is only used in errors.
            if isinstance(d, dict):
                # HACK: Try to convert string keys to int if they look like integers. Handles cases
                # where integer keys were serialized as strings (eg "block_overrides").
                # ``isdecimal()`` rather than ``isdigit()``: strings such as "²" are digits but
                # not decimals, and ``int()`` would reject them.
                d = {
                    (int(k) if isinstance(k, str) and k.isdecimal() else k): v
                    for k, v in d.items()
                }
                new_dict = {
                    k: decode_data(v, f"{prefix}.{k}" if prefix else str(k))
                    for k, v in d.items()
                    if k != cls.CLASS_NAME_FIELD
                }
                cls_name = d.get(cls.CLASS_NAME_FIELD)
                cls_o = resolve_cls(cls_name) if cls_name is not None else None
                if cls_o is None:
                    return new_dict
                # Remove ignored fields if the class defines any
                if (ignore_fields := getattr(cls_o, "_IGNORE_FIELDS", None)) is not None:
                    new_dict = {k: v for k, v in new_dict.items() if k not in ignore_fields}
                # Remove the "type" entry (the registered name) since the class is already
                # resolved via _CLASS_. This avoids a registry lookup on the resolved subclass,
                # whose own _registry may be empty (registrations live on the parent class).
                # Keep it when "type" is a real constructor field of the resolved class.
                if not has_init_field(cls_o, "type"):
                    new_dict.pop("type", None)
                try:
                    return decode(cls_o, new_dict)  # type: ignore[arg-type]
                except Exception as e:
                    if prefix:
                        msg = f"Failed to construct '{prefix}' in config: {e}"
                    else:
                        msg = f"Error building config: {e}"
                    raise OLMoConfigurationError(msg) from e
            elif isinstance(d, (list, tuple, set)):
                return d.__class__(
                    decode_data(x, f"{prefix}.{i}" if prefix else str(i)) for i, x in enumerate(d)
                )
            else:
                return d

        if overrides:
            data = copy_containers(data)
            for key, value in _clean_opts(overrides):
                _set_nested(data, key, value)
        decoded = decode_data(data, "")
        if isinstance(decoded, cls):
            return decoded
        return decode(cls, decoded)

    @classmethod
    def from_file(cls: Type[C], path: PathOrStr, overrides: Optional[List[str]] = None) -> C:
        """
        Load from a YAML (``.yml``/``.yaml``) or JSON (``.json``) file, chosen by extension.

        :param path: Path (or anything :func:`cached_path` accepts) to the config file.
        :param overrides: A list of field overrides with dot notation, e.g. ``foo.bar=1``.
        :raises OLMoConfigurationError: If the file extension is not supported.
        """
        path_str = str(path)
        if path_str.endswith((".yml", ".yaml")):
            return cls.from_yaml(path, overrides=overrides)
        elif path_str.endswith(".json"):
            return cls.from_json(path, overrides=overrides)
        else:
            raise OLMoConfigurationError(f"Unsupported config file type: {path}")

    @classmethod
    def from_json(cls: Type[C], path: PathOrStr, overrides: Optional[List[str]] = None) -> C:
        """
        Load from a JSON file whose top level is an object.

        :param path: Path (or anything :func:`cached_path` accepts) to the JSON file.
        :param overrides: A list of field overrides with dot notation, e.g. ``foo.bar=1``.
        :raises OLMoConfigurationError: If the top level of the file is not an object.
        """
        # JSON text is UTF-8 by specification; don't depend on the platform's locale encoding.
        with cached_path(path).open(encoding="utf-8") as f:
            config_dict = json.load(f)
        if not isinstance(config_dict, dict):
            raise OLMoConfigurationError(
                f"Expected a mapping at the top level of config file '{path}', "
                f"got {type(config_dict).__name__}"
            )
        return cls.from_dict(config_dict, overrides=overrides)

    @classmethod
    def from_yaml(cls: Type[C], path: PathOrStr, overrides: Optional[List[str]] = None) -> C:
        """
        Load from a YAML file whose top level is a mapping.

        :param path: Path (or anything :func:`cached_path` accepts) to the YAML file.
        :param overrides: A list of field overrides with dot notation, e.g. ``foo.bar=1``.
        :raises OLMoConfigurationError: If the top level of the file is not a mapping (this
            includes an empty file, which ``yaml.safe_load`` reads as ``None``).
        """
        with cached_path(path).open(encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        if not isinstance(config_dict, dict):
            raise OLMoConfigurationError(
                f"Expected a mapping at the top level of config file '{path}', "
                f"got {type(config_dict).__name__}"
            )
        return cls.from_dict(config_dict, overrides=overrides)
def _set_nested(data: Any, key: str, value: Any):
    """
    Set ``value`` inside a nested structure of dicts and lists at a dot-separated ``key``.

    Each dotted component indexes one level: dicts by the component string as-is, lists by
    the component converted with :class:`int` (e.g. ``"blocks.0.size"``). The structure is
    modified in place.

    :raises KeyError: If an intermediate dict key does not exist.
    :raises IndexError: If a list index is out of range.
    :raises ValueError: If a list index is not an integer, or a level is neither a dict nor
        a list.
    """
    head, sep, rest = key.partition(".")
    if sep:
        if isinstance(data, dict):
            child = data[head]
        elif isinstance(data, list):
            child = data[int(head)]
        else:
            # Report the same context as the leaf case below instead of just the bare object.
            raise ValueError(f"Can't set value '{value}' at key '{key}' for object {data}")
        _set_nested(child, rest, value)
    else:
        if isinstance(data, dict):
            data[key] = value
        elif isinstance(data, list):
            data[int(key)] = value
        else:
            raise ValueError(f"Can't set value '{value}' at key '{key}' for object {data}")
def _clean_opts(opts: Sequence[str]) -> list[tuple[str, Any]]:
    """Parse every ``name=value`` option string in ``opts`` with :func:`_clean_opt`, in order."""
    return list(map(_clean_opt, opts))
def _clean_opt(arg: str) -> tuple[str, Any]:
    """
    Parse one command-line style override into a ``(name, value)`` pair.

    The name has surrounding spaces and dashes stripped and its remaining dashes replaced
    with underscores. An argument without ``=`` is treated as ``name=true``; an empty or
    whitespace-only value becomes ``""``; any other value is parsed with
    :func:`yaml.safe_load`.
    """
    name, sep, raw = arg.partition("=")
    if not sep:
        raw = "true"
    name = name.strip(" -").replace("-", "_")
    parsed: Any = "" if raw.strip() == "" else yaml.safe_load(raw)
    return (name, parsed)
[docs]
class DType(StrEnum):
    """
    An enumeration of supported PyTorch data types.

    Each member's value is the name of the matching attribute on the ``torch`` module.
    """

    float32 = "float32"
    bfloat16 = "bfloat16"
    float16 = "float16"
    float8_e4m3fn = "float8_e4m3fn"  # note: other e4m3 variants are supported in torch
    float8_e5m2 = "float8_e5m2"

    @classmethod
    def from_pt(cls, dtype: torch.dtype) -> "DType":
        """
        Return the member corresponding to a PyTorch dtype.

        :raises NotImplementedError: If ``dtype`` is not one of the supported types.
        """
        # Members are tried in definition order, so a ``torch`` attribute (e.g. the float8
        # ones) is only looked up once every earlier candidate has failed to match.
        for candidate in DType:
            if dtype == getattr(torch, candidate.value):
                return candidate
        raise NotImplementedError(dtype)

    def as_pt(self) -> torch.dtype:
        """Return the ``torch.dtype`` named by this member."""
        return getattr(torch, self.value)
class _Unset:
    """
    Sentinel type whose instance marks a value that was not explicitly provided.

    The module-level instance is exported as ``UNSET``; it is used internally to detect
    when a value should be skipped.
    """

    def __repr__(self) -> str:
        text = "<UNSET>"
        return text


# The shared sentinel instance (typed ``Any`` so it can be used as a default for any field).
UNSET: Any = _Unset()