Source code for olmo_core.nn.conversion.state_converter

import itertools
from dataclasses import dataclass
from typing import Any, Dict, List

import torch

from olmo_core.doc_utils import beta_feature
from olmo_core.nn.conversion.state_mapping import (
    StateMapping,
    StateMappingTemplate,
    StateType,
    TemplatePlaceholder,
)


@beta_feature
@dataclass
class StateConverter:
    """
    A class for converting state from one format to another format (e.g. OLMo Core to HF).

    Conversion is driven by :data:`mapping_templates`. Each template is expanded into
    concrete :class:`StateMapping` objects by filling in its placeholders, and every mapping
    whose source keys are all present in the state dict is then applied.
    """

    # Templates describing how source keys map to destination keys. Only templates whose
    # ``state_type`` equals the requested state type are used.
    mapping_templates: List[StateMappingTemplate]

    def _fill_placeholders(
        self,
        mapping: StateMappingTemplate,
        placeholder_values: Dict[TemplatePlaceholder, int | None],
        placeholder_bounds: Dict[TemplatePlaceholder, int],
    ) -> StateMapping | None:
        """
        Turns a mapping template into a concrete mapping by delegating to
        ``StateMappingTemplate.to_mapping``.

        :param mapping: The template to fill in.
        :param placeholder_values: The value to use for each placeholder, or ``None`` to leave
            that placeholder unset.
        :param placeholder_bounds: Upper bound values for the placeholders.

        :returns: Whatever ``to_mapping`` returns; callers in this class treat ``None`` as
            "this combination of placeholder values is invalid for the template".
        """
        return mapping.to_mapping(placeholder_values, placeholder_bounds)

    def _get_mappings(
        self,
        state_dict: Dict[str, Any],
        placeholder_bounds: Dict[TemplatePlaceholder, int],
        state_type: StateType = StateType.weight,
    ) -> List[StateMapping]:
        """
        Builds every concrete mapping that applies to ``state_dict``.

        :param state_dict: The state dictionary in unconverted format. Only its keys are used.
        :param placeholder_bounds: Upper bound values for any relevant placeholders.
        :param state_type: Only templates with this state type are considered.

        :returns: The mappings whose source keys are all present in ``state_dict``, ordered by
            template and then by placeholder value combination.
        """
        # We consider all combinations of placeholders, including allowing each placeholder to
        # not be set (``None``). If a placeholder is set when not needed, the combination will
        # be treated as invalid and so ignored.
        per_placeholder_options = [
            [(placeholder, value) for value in (*range(bound), None)]
            for placeholder, bound in placeholder_bounds.items()
        ]
        placeholder_value_combinations: List[Dict[TemplatePlaceholder, int | None]] = [
            dict(combination) for combination in itertools.product(*per_placeholder_options)
        ]

        # Fill in the placeholders in the mapping templates of the requested state type.
        state_mappings = [
            self._fill_placeholders(template, combination, placeholder_bounds)
            for template in self.mapping_templates
            if template.state_type == state_type
            for combination in placeholder_value_combinations
        ]

        # Keep only valid mappings whose source keys all exist in the given state dict.
        state_keys = set(state_dict)
        return [
            mapping
            for mapping in state_mappings
            if mapping is not None and all(key in state_keys for key in mapping.source_keys)
        ]

    def get_mappings(
        self,
        state_dict: Dict[str, Any],
        placeholder_bounds: Dict[TemplatePlaceholder, int],
        state_type: StateType = StateType.weight,
    ) -> List[StateMapping]:
        """
        Gets the state mapping from the given state dict to the converted format, without
        performing conversion.

        :param state_dict: The state dictionary in unconverted format.
        :param placeholder_bounds: Upper bound values for any relevant placeholders (e.g. for
            ``TemplatePlaceholder.EXPERT``, the number of experts).
        :param state_type: The type of state this state dict corresponds to. Defaults to
            ``StateType.weight``.

        :returns: The mappings whose source keys are all present in ``state_dict``.
        """
        return self._get_mappings(state_dict, placeholder_bounds, state_type=state_type)

    def convert(
        self,
        state_dict: Dict[str, Any],
        placeholder_bounds: Dict[TemplatePlaceholder, int],
        state_type: StateType = StateType.weight,
    ) -> Dict[str, Any]:
        """
        Converts a state dict to another format. This currently only supports tensor values.

        For each applicable mapping, the source tensors are concatenated along
        ``source_concat_dim``, optionally unflattened, permuted and flattened (in that order),
        and then chunked along ``dest_chunk_dim`` into one tensor per destination key.

        :param state_dict: The state dictionary to convert.
        :param placeholder_bounds: Upper bound values for any relevant placeholders (e.g. for
            ``TemplatePlaceholder.EXPERT``, the number of experts).
        :param state_type: The type of state this state dict corresponds to. Defaults to
            ``StateType.weight``.

        :returns: A new dictionary from destination keys to contiguous tensors.

        :raises RuntimeError: If a mapping's first source value is not a tensor, if chunking
            does not yield exactly one chunk per destination key, or if some keys of
            ``state_dict`` are not covered by any mapping.
        """
        state_mappings = self._get_mappings(state_dict, placeholder_bounds, state_type=state_type)

        unused_original_keys = set(state_dict)
        converted_state_dict: Dict[str, Any] = {}
        for mapping in state_mappings:
            original_keys = mapping.source_keys
            converted_keys = mapping.dest_keys

            # Only the first source value is type-checked; the remaining values are handed to
            # ``torch.cat`` as they are.
            if not isinstance(state_dict[original_keys[0]], torch.Tensor):
                raise RuntimeError(
                    f"Attempting to map {len(original_keys)} non-tensor states to {len(converted_keys)} keys"
                )

            original_state = torch.cat(
                [state_dict[key] for key in original_keys],
                dim=mapping.source_concat_dim,
            )

            # Optional reshaping steps; each mapping attribute is unpacked into the
            # positional arguments of the corresponding tensor method.
            if mapping.unflatten_dim is not None:
                original_state = original_state.unflatten(*mapping.unflatten_dim)
            if mapping.dims_permutation is not None:
                original_state = original_state.permute(*mapping.dims_permutation)
            if mapping.flatten_dims is not None:
                original_state = original_state.flatten(*mapping.flatten_dims)

            state_chunks = torch.chunk(
                original_state, chunks=len(converted_keys), dim=mapping.dest_chunk_dim
            )
            # ``torch.chunk`` may return fewer chunks than requested. Zipping would then
            # silently drop destination keys, so fail loudly instead.
            if len(state_chunks) != len(converted_keys):
                raise RuntimeError(
                    f"Expected {len(converted_keys)} chunks for keys {list(converted_keys)} "
                    f"but only {len(state_chunks)} could be produced from a tensor of shape "
                    f"{tuple(original_state.shape)} along dim {mapping.dest_chunk_dim}"
                )

            for converted_key, state_chunk in zip(converted_keys, state_chunks):
                converted_state_dict[converted_key] = state_chunk.contiguous()

            unused_original_keys -= set(original_keys)

        if unused_original_keys:
            raise RuntimeError(
                f"Some state keys were not converted: {sorted(unused_original_keys)}"
            )

        return converted_state_dict