Source code for olmo_core.nn.conversion.state_mapping

from dataclasses import dataclass
from typing import Any, Dict, Optional, Set, Tuple

from olmo_core.config import StrEnum


class TemplatePlaceholder(StrEnum):
    """
    The placeholder tokens that may appear in the templates of
    :class:`StateMappingTemplate`.
    """

    LAYER = "[layer]"
    """
    Placeholder standing in for the layer number.
    """

    EXPERT = "[expert]"
    """
    Placeholder standing in for an expert (e.g. one of the experts of an MoE model).
    """
class StateType(StrEnum):
    """
    The kind of state that a conversion applies to.
    """

    weight = "weight"
    """
    The state being converted is a weight. This is the useful choice when converting
    between checkpoints, where the state is the weight itself.
    """

    module = "module"
    """
    The state being converted is a module. This can be useful for comparing activations
    between different implementations of the same model, where the states are the
    activations of submodules.
    """
@dataclass
class StateMappingTemplate:
    """
    A template describing how state in one format maps to state in another format
    (e.g. OLMo Core to HF).

    These are 'templates' because their keys and other metadata may contain placeholders
    for information such as the layer number or the number of MoE experts. Supplying the
    placeholder information (see :meth:`to_mapping`) turns a template into a concrete
    :class:`StateMapping`.

    The simplest case is a one-to-one mapping, expressed as a single string for both
    :data:`source_template_keys` and :data:`dest_template_keys`. More involved mappings
    are also supported, such as many-to-many mappings or mappings that additionally
    manipulate the state (for example by permuting dimensions).
    """

    source_template_keys: str | Tuple[str, ...]
    """
    The template key or keys of the state(s) being mapped from.
    """

    dest_template_keys: str | Tuple[str, ...]
    """
    The template key or keys of the state(s) being mapped to.
    """

    state_type: StateType = StateType.weight

    source_key_per_placeholder: TemplatePlaceholder | None = None
    """
    A placeholder in :data:`source_template_keys` that this mapping should expand over
    all of its valid values, instead of over 1 specific value. For example, this allows
    the states of all experts (using ``TemplatePlaceholder.EXPERT``) to be mapped to a
    single state.

    When set, :data:`source_template_keys` must be a string.
    """

    dest_key_per_placeholder: TemplatePlaceholder | None = None
    """
    A placeholder in :data:`dest_template_keys` that this mapping should expand over
    all of its valid values, instead of over 1 specific value. For example, this allows
    a single state to be mapped to the states of all experts (using
    ``TemplatePlaceholder.EXPERT``).

    When set, :data:`dest_template_keys` must be a string.
    """

    source_concat_dim: int = 0
    """
    The dimension along which the source states are combined when there are many of them.
    """

    unflatten_dim: Tuple[int, Tuple[TemplatePlaceholder | int, ...]] | None = None
    """
    Requests that dimension ``unflatten_dim[0]`` be unflattened into the shape
    ``unflatten_dim[1]``. A shape entry may be a placeholder rather than a number, in
    which case it stands for that placeholder's upper bound (e.g.
    ``TemplatePlaceholder.EXPERT`` stands for the number of experts).
    """

    dims_permutation: Tuple[int, ...] | None = None
    """
    The permutation to apply to the dimensions of the state, after any unflattening
    requested by :data:`unflatten_dim`.
    """

    flatten_dims: Tuple[int, int] | None = None
    """
    Requests that all dimensions between the 2 given dimensions (inclusive) be flattened,
    after any permutation requested by :data:`dims_permutation`.
    """

    dest_chunk_dim: int = 0
    """
    The dimension along which the state is (evenly) chunked when there are many
    destination states.
    """

    def __post_init__(self):
        """
        Validate that a "key per placeholder" is never combined with multiple template keys.

        :raises ValueError: If :data:`source_key_per_placeholder` (resp.
            :data:`dest_key_per_placeholder`) is set while :data:`source_template_keys`
            (resp. :data:`dest_template_keys`) is a tuple.
        """
        # The source side is checked before the destination side.
        for expanded_placeholder, template_keys in (
            (self.source_key_per_placeholder, self.source_template_keys),
            (self.dest_key_per_placeholder, self.dest_template_keys),
        ):
            if expanded_placeholder and isinstance(template_keys, tuple):
                raise ValueError(
                    f"Having a key per {expanded_placeholder} is not supported with multiple template keys"
                )

    def _templates_to_keys(
        self,
        placeholder_values: Dict[TemplatePlaceholder, Any],
        placeholder_bounds: Dict[TemplatePlaceholder, int],
        *,
        source: bool,
    ) -> Tuple[str, ...] | None:
        """
        Resolve the source or destination template key(s) into concrete keys.

        :param placeholder_values: The value to substitute for each placeholder. A value of
            ``None`` means the placeholder is expected to be absent from the templates.
        :param placeholder_bounds: The bound of each placeholder. When a "key per placeholder"
            is configured for the chosen side, one key is produced for every value in
            ``range(bound)`` of that placeholder.
        :param source: If ``True`` the source templates are resolved, otherwise the
            destination templates are.

        :returns: The resolved keys, or ``None`` when ``placeholder_values`` does not fit
            the templates (a value given for an absent placeholder, no value given for a
            present placeholder, or a placeholder left unfilled).

        :raises ValueError: If a placeholder must be expanded but the template is not a
            string, does not contain the placeholder, or the placeholder has no bound.
        """
        if source:
            templates, expanded_placeholder = (
                self.source_template_keys,
                self.source_key_per_placeholder,
            )
        else:
            templates, expanded_placeholder = (
                self.dest_template_keys,
                self.dest_key_per_placeholder,
            )

        if expanded_placeholder:
            if not isinstance(templates, str):
                raise ValueError(
                    "Invalid template; template must be a string when expanding a placeholder"
                )
            base_template = templates
            if expanded_placeholder not in base_template:
                raise ValueError(
                    f"Invalid template; placeholder {expanded_placeholder} is being expanded but is not present in template {base_template}"
                )
            if expanded_placeholder not in placeholder_bounds:
                raise ValueError(
                    f"Invalid bounds; placeholder {expanded_placeholder} does not have a bound"
                )

            # One template per valid value of the expanded placeholder.
            bound = placeholder_bounds[expanded_placeholder]
            templates = tuple(
                base_template.replace(expanded_placeholder, str(index)) for index in range(bound)
            )
        elif isinstance(templates, str):
            templates = (templates,)

        assert isinstance(templates, tuple)

        resolved_keys = []
        for template in templates:
            resolved = template
            for placeholder, value in placeholder_values.items():
                is_present = placeholder in template
                has_value = value is not None
                if is_present != has_value:
                    # A placeholder that is given a value but is not present, or that is
                    # present but not given a value, makes the placeholder values invalid.
                    return None
                if is_present:
                    resolved = resolved.replace(placeholder, str(value))
            resolved_keys.append(resolved)

        # A placeholder that is still unfilled means its value was not provided at all.
        if any(
            placeholder in resolved
            for resolved in resolved_keys
            for placeholder in TemplatePlaceholder
        ):
            return None

        return tuple(resolved_keys)

    def to_mapping(
        self,
        placeholder_values: Dict[TemplatePlaceholder, int | None],
        placeholder_bounds: Dict[TemplatePlaceholder, int],
    ) -> Optional["StateMapping"]:
        """
        Turn this template into a concrete :class:`StateMapping`.

        :param placeholder_values: The value to substitute for each placeholder, or ``None``
            for a placeholder that should not appear in the template keys.
        :param placeholder_bounds: The bound of each placeholder, used to expand a
            "key per placeholder" and to resolve placeholders in :data:`unflatten_dim`.

        :returns: The resulting mapping, or ``None`` if a placeholder this template needs
            has no bound or if the source or destination keys cannot be resolved with the
            given placeholder values.

        :raises ValueError: If the source or destination template is invalid for
            placeholder expansion.
        """
        needed_placeholders: Set[TemplatePlaceholder] = {
            placeholder
            for placeholder in (self.source_key_per_placeholder, self.dest_key_per_placeholder)
            if placeholder
        }
        if self.unflatten_dim:
            needed_placeholders.update(
                dim for dim in self.unflatten_dim[1] if isinstance(dim, TemplatePlaceholder)
            )

        if not needed_placeholders.issubset(placeholder_bounds.keys()):
            # This may be, say, an MoE mapping for which we do not have any expert values.
            # This is ok; we simply discard this mapping.
            return None

        # Both sides are resolved before either result is inspected.
        source_keys = self._templates_to_keys(placeholder_values, placeholder_bounds, source=True)
        dest_keys = self._templates_to_keys(placeholder_values, placeholder_bounds, source=False)
        if source_keys is None or dest_keys is None:
            return None

        resolved_unflatten_dim = None
        if self.unflatten_dim is not None:
            # Replace each placeholder in the target shape by its bound.
            resolved_shape = tuple(
                placeholder_bounds[dim] if isinstance(dim, TemplatePlaceholder) else int(dim)
                for dim in self.unflatten_dim[1]
            )
            resolved_unflatten_dim = (self.unflatten_dim[0], resolved_shape)

        return StateMapping(
            source_keys,
            dest_keys,
            state_type=self.state_type,
            source_concat_dim=self.source_concat_dim,
            unflatten_dim=resolved_unflatten_dim,
            dims_permutation=self.dims_permutation,
            flatten_dims=self.flatten_dims,
            dest_chunk_dim=self.dest_chunk_dim,
        )
@dataclass
class StateMapping:
    """
    A concrete mapping of state from one format to another format (e.g. OLMo Core to HF).

    The simplest case is a one-to-one mapping, expressed as a single entry in both
    :data:`source_keys` and :data:`dest_keys`. More involved mappings are also supported,
    such as many-to-many mappings or mappings that additionally manipulate the state
    (for example by permuting dimensions).
    """

    source_keys: Tuple[str, ...]
    """
    The key(s) of the state(s) being mapped from.
    """

    dest_keys: Tuple[str, ...]
    """
    The key(s) of the state(s) being mapped to.
    """

    state_type: StateType = StateType.weight

    source_concat_dim: int = 0
    """
    The dimension along which the source states are combined when there are many of them.
    """

    unflatten_dim: Tuple[int, Tuple[int, ...]] | None = None
    """
    Requests that dimension ``unflatten_dim[0]`` be unflattened into the shape
    ``unflatten_dim[1]``.
    """

    dims_permutation: Tuple[int, ...] | None = None
    """
    The permutation to apply to the dimensions of the state, after any unflattening
    requested by :data:`unflatten_dim`.
    """

    flatten_dims: Tuple[int, int] | None = None
    """
    Requests that all dimensions between the 2 given dimensions (inclusive) be flattened,
    after any permutation requested by :data:`dims_permutation`.
    """

    dest_chunk_dim: int = 0
    """
    The dimension along which the state is (evenly) chunked when there are many
    destination states.
    """