Source code for olmo_core.generate.generation_module.generation_module

import logging
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional

import torch.distributed as dist
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.stateful import Stateful

log = logging.getLogger(__name__)


[docs] class GenerationModule(Stateful, metaclass=ABCMeta): @property def dp_process_group(self) -> Optional[dist.ProcessGroup]: """ Should return the data parallel process group if it's anything other than the default process group. """ return None def state_dict_to_load(self, metadata: Metadata) -> Dict[str, Any]: del metadata return self.state_dict()
[docs] @abstractmethod def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ Load a state dict. """ raise NotImplementedError