Source code for dataclass_extensions.registrable

from __future__ import annotations

import dataclasses
import sys
import typing
from dataclasses import dataclass
from typing import Callable, ClassVar, Type, TypeVar

R = TypeVar("R", bound="Registrable")

if sys.version_info < (3, 11):
    # HACK: work-around for https://stackoverflow.com/q/70400639/4151392
    dataclasses.InitVar.__call__ = lambda *_: None  # type: ignore


[docs] @dataclass class Registrable: _registry: ClassVar[dict[str, Type[Registrable]]] _default_type: ClassVar[str | None] registered_name: ClassVar[str | None] registered_base: ClassVar[Type[Registrable] | None] type: dataclasses.InitVar[str | None] = dataclasses.field( default=None, kw_only=True, repr=False ) def __new__(cls, *args, type: str | None = None, **kwargs): del args, kwargs if type is not None and type != cls.registered_name: if type not in cls._registry: raise KeyError( f"'{type}' is not registered name for {cls.__name__}. " f"Available choices are: {list(cls._registry.keys())}" ) return super().__new__(cls._registry[type]) elif cls._default_type is not None and cls.registered_name is None: return super().__new__(cls._registry[cls._default_type]) else: return super().__new__(cls) def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls._registry = {} cls._default_type = None if not hasattr(cls, "registered_name"): cls.registered_name = None if not hasattr(cls, "registered_base"): cls.registered_base = None @classmethod def register(cls, name: str, default: bool = False) -> Callable[[Type[R]], Type[R]]: if cls is Registrable: raise TypeError("Cannot register with the base Registrable class itself") def register_subclass(subclass: Type[R]) -> Type[R]: if subclass is Registrable: raise TypeError("Cannot register the base Registrable class itself") if not issubclass(subclass, cls): raise TypeError( f"class {subclass.__name__} must be a subclass of {cls.__name__} in order to register it" ) if not dataclasses.is_dataclass(subclass): raise TypeError( f"class {subclass.__name__} must be a dataclass in order to register it" ) if default: if cls._default_type is not None: raise ValueError( f"A default implementation for {cls.__name__} has already been registered" ) else: cls._default_type = name cls._registry[name] = subclass subclass.registered_name = name subclass.registered_base = cls return subclass # type: ignore return register_subclass @classmethod def get_registered_name(cls: Type[R], subclass: Type[R] | None = None) -> str: if cls is Registrable: raise TypeError("Cannot register the base Registrable class itself") if subclass is None: if cls.registered_name is not None: return cls.registered_name else: raise ValueError( f"class {cls.__name__} is not a registered subclass of any base registrable class" ) for name, registered_subclass in cls._registry.items(): if registered_subclass == subclass: return name raise ValueError( f"class {subclass.__name__} is not a registered subclass of {cls.__name__}" ) @classmethod def get_registered_class(cls: Type[R], type: str) -> Type[R]: if cls is Registrable: raise TypeError("Cannot register with the base Registrable class itself") if type not in cls._registry: raise KeyError( f"'{type}' is not registered name for {cls.__name__}. " f"Available choices are: {cls.get_registered_names()}" ) return typing.cast(Type[R], cls._registry[type]) @classmethod def get_registered_names(cls) -> list[str]: if cls is Registrable: raise TypeError("Cannot register with the base Registrable class itself") return list(cls._registry.keys()) @classmethod def get_default(cls: Type[R]) -> Type[R]: if cls is Registrable: raise TypeError("Cannot register with the base Registrable class itself") if cls._default_type is None: raise ValueError(f"A default implementation of {cls.__name__} has not been registered") return cls.get_registered_class(cls._default_type)