import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Generator, Optional
import torch
import torch.distributed as dist
from huggingface_hub import repo_exists
from torch.distributed.tensor import DTensor, distribute_tensor
from transformers import AutoModelForCausalLM, AutoTokenizer
from olmo_core.aliases import PathOrStr
from olmo_core.config import DType
from olmo_core.distributed.utils import barrier, get_fs_local_rank, get_full_tensor
from olmo_core.doc_utils import beta_feature
from olmo_core.io import clear_directory, copy_dir, file_exists, is_url
from olmo_core.nn.hf.config import (
get_hf_config,
get_hybrid_hf_config,
get_hybrid_layer_types,
)
from olmo_core.nn.hf.convert import (
convert_hybrid_state_to_hf,
convert_state_from_hf,
convert_state_to_hf,
)
from olmo_core.nn.transformer.model import Transformer
# Module-level logger; the fallback context manager below reports through it.
log = logging.getLogger(__name__)

try:
    from accelerate import init_empty_weights  # type: ignore
except ImportError:

    @contextmanager
    def init_empty_weights(include_buffers: bool = False) -> Generator[None, None, None]:
        """
        Stand-in for ``accelerate.init_empty_weights`` used when ``accelerate`` cannot be imported.

        It only logs a warning and yields; nothing about weight initialization is changed.

        :param include_buffers: Accepted so the call signature matches, but ignored.
        """
        del include_buffers
        log.warning("accelerate not installed, will initialize weights.")
        yield
[docs]
def _has_hf_checkpoint_files(model_dir: PathOrStr) -> bool:
    """Return ``True`` if ``model_dir`` holds at least one of the files checked for an HF model."""
    return any(
        file_exists(f"{model_dir}/{file_name}")
        for file_name in (
            "generation_config.json",
            "model.safetensors.index.json",
            "pytorch_model.bin",
        )
    )


@beta_feature
def load_hf_model(
    model_name_or_path: PathOrStr,
    model_state_dict: Dict[str, Any],
    *,
    revision: str = "main",
    model_id: Optional[str] = None,
    num_embeddings: Optional[int] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    work_dir: Optional[PathOrStr] = None,
):
    """
    Loads an OLMo Core model state dict using a model in Hugging Face transformers format.

    The entries of ``model_state_dict`` are replaced in place with the converted HF weights;
    nothing is returned. Entries that are :class:`DTensor` are re-distributed with the same
    device mesh and placements as the entry they replace.

    :param model_name_or_path: The name of a model in HF Hub or the path to a model saved in HF format.
    :param model_state_dict: The OLMo Core model state dict in which to load HF state.
    :param revision: If ``model_name_or_path`` is the id of a model in HF Hub, then this is the revision
        (branch) of that model. Defaults to "main".
    :param model_id: Deprecated, model-specific mappings are now determined by the model architecture,
        in :mod:`olmo_core.nn.hf.convert`
    :param num_embeddings: The number of embeddings in the OLMo Core model being loaded into,
        defaults to the number of embeddings in the HF model.
    :param process_group: The process group to use for distributed communication.
    :param work_dir: A local directory that can be used for holding temporary state. Required when
        downloading a model from a cloud directory. An ``hf-tmp`` sub-directory of it is used and
        cleared once loading is done.

    :raises NotImplementedError: If ``model_name_or_path`` is neither a URL, a local directory,
        nor an existing HF Hub repo.
    :raises KeyError: If a converted HF key is missing from ``model_state_dict``.
    """
    del model_id  # Deprecated and unused; accepted only for backwards compatibility.

    work_dir = f"{work_dir}/hf-tmp" if work_dir is not None else None

    if is_url(model_name_or_path):
        log.warning(
            "Model id or path provided is a remote Hugging Face directory. This may not be suitable for unshared file systems."
        )
        assert work_dir is not None
        assert _has_hf_checkpoint_files(model_name_or_path)

        # Download model to local FS
        if get_fs_local_rank() == 0:
            copy_dir(model_name_or_path, work_dir)
        barrier(group=process_group)

        # Load from the local copy that was just downloaded. Without this the copy in
        # ``work_dir`` would never be read and ``from_pretrained`` would receive the URL.
        model_name_or_path = work_dir
    elif Path(model_name_or_path).is_dir():
        assert _has_hf_checkpoint_files(model_name_or_path)
    elif repo_exists(str(model_name_or_path)):
        log.warning(
            "Model id or path provided is a Hugging Face model id. This may not be suitable for unshared file systems."
        )
    else:
        raise NotImplementedError

    # Warm up the HF local cache by downloading the model on just local rank 0
    if get_fs_local_rank() == 0:
        hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, revision=revision)
        del hf_model
    barrier(group=process_group)

    hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, revision=revision)
    log.info(f"Loaded hf model: {hf_model}")

    hf_model.resize_token_embeddings(num_embeddings)

    converted_state_dict: Dict[str, torch.Tensor] = convert_state_from_hf(
        hf_model.config,
        hf_model.state_dict(),
        model_type=getattr(hf_model.config, "model_type", None),
    )

    # Iterate in sorted key order so every rank processes the keys in the same sequence.
    for key in sorted(converted_state_dict):
        state = converted_state_dict[key]
        olmo_core_state = model_state_dict[key]
        if isinstance(olmo_core_state, DTensor):
            # Keep the sharding layout of the entry that is being replaced.
            model_state_dict[key] = distribute_tensor(
                state, olmo_core_state.device_mesh, olmo_core_state.placements
            )
        else:
            model_state_dict[key] = state

    if work_dir:
        # Wait until every rank has finished reading before the temporary copy is removed.
        barrier(group=process_group)
        clear_directory(work_dir)
[docs]
@beta_feature
def save_hf_model(
    save_dir: PathOrStr,
    model_state_dict: Dict[str, Any],
    model: Transformer,
    huggingface_tokenizer: Optional[AutoTokenizer] = None,
    *,
    dtype: Optional[DType] = None,
    vocab_size: Optional[int] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    work_dir: Optional[PathOrStr] = None,
    save_overwrite: bool = False,
):
    """
    Saves an OLMo Core model state dict in Hugging Face transformers format.

    Every rank gathers and converts the state, but only the rank for which
    ``get_fs_local_rank(process_group)`` is ``0`` writes files.

    :param save_dir: Directory in which to save model.
    :param model_state_dict: The OLMo Core model state dict being saved in HF format.
    :param model: The OLMo Core model; used to build the HF config and for the default vocab size.
    :param huggingface_tokenizer: Optional tokenizer. If given, the generation config's
        ``eos_token_id`` is set to the ids of ``<|im_end|>`` and ``<|endoftext|>`` and its
        ``pad_token_id`` to the tokenizer's pad token id.
    :param dtype: The torch dtype that model weights should be saved as.
    :param vocab_size: The size of the vocab, defaults to the number of embeddings in the OLMo Core model.
    :param process_group: The process group to use for distributed communication.
    :param work_dir: A local directory that can be used for holding temporary state. Required when
        ``save_dir`` is a URL: the model is first saved here and then copied to ``save_dir``.
    :param save_overwrite: Overwrite existing files in ``save_dir``.

    :raises FileExistsError: If ``save_dir`` is an existing local directory and ``save_overwrite``
        is ``False``.
    """
    hf_config = get_hf_config(model)

    model_state_dict = {key: get_full_tensor(state) for key, state in model_state_dict.items()}
    if dtype is not None:
        model_state_dict = {
            key: state.to(dtype=dtype.as_pt()) for key, state in model_state_dict.items()
        }
    hf_state_dict: Dict[str, torch.Tensor] = convert_state_to_hf(hf_config, model_state_dict)

    # ``model.save_pretrained`` fails, saying `tensor.reshape()` should be used instead of
    # `tensor.view()`, if we do not make the state contiguous. Unfortunately this is bad for perf.
    hf_state_dict = {key: state.contiguous() for key, state in hf_state_dict.items()}

    with init_empty_weights():
        log.info("Initializing HF model with empty weights...")
        hf_model = AutoModelForCausalLM.from_config(hf_config)
    del hf_config

    # ``assign=True`` swaps the converted tensors in rather than copying into the existing params.
    hf_model.load_state_dict(hf_state_dict, assign=True)

    hf_model.config.vocab_size = vocab_size or model.vocab_size
    hf_model.resize_token_embeddings(hf_model.config.vocab_size)

    hf_model.generation_config.do_sample = True
    if huggingface_tokenizer is not None:
        hf_model.generation_config.eos_token_id = huggingface_tokenizer.convert_tokens_to_ids(
            ["<|im_end|>", "<|endoftext|>"]
        )
        # The generation config field is ``pad_token_id``; a ``pad_token`` attribute would be
        # stored but never used as the padding id.
        hf_model.generation_config.pad_token_id = huggingface_tokenizer.pad_token_id

    if get_fs_local_rank(process_group) == 0:
        if is_url(save_dir):
            # Remote target: save locally first, then upload the directory.
            assert work_dir is not None
            hf_model.save_pretrained(work_dir)
            copy_dir(work_dir, save_dir, save_overwrite=save_overwrite)
        else:
            target = Path(save_dir)
            if target.is_dir() and not save_overwrite:
                raise FileExistsError(target)
            target.parent.mkdir(exist_ok=True, parents=True)
            hf_model.save_pretrained(target)
[docs]
@beta_feature
def save_hf_hybrid_model(
    save_dir: PathOrStr,
    model_state_dict: Dict[str, Any],
    model: Transformer,
    *,
    dtype: Optional[DType] = None,
    vocab_size: Optional[int] = None,
    max_sequence_length: int = 65536,
) -> None:
    """
    Save a hybrid (GDN + attention) model as ``config.json`` + ``model.safetensors``.

    Unlike :func:`save_hf_model`, this writes files directly to avoid a hard dependency
    on a specific ``transformers`` version.

    :param save_dir: Directory in which to save the model. It is created if missing.
    :param model_state_dict: The OLMo-core model state dict.
    :param model: The OLMo-core hybrid transformer model.
    :param dtype: Optional dtype to cast weights to.
    :param vocab_size: If set, ``vocab_size`` in the config is set to this value and the first
        dimension of ``model.embed_tokens.weight`` / ``lm_head.weight`` is sliced to at most
        this size.
    :param max_sequence_length: Maximum sequence length for ``max_position_embeddings``.
    """
    import json

    from safetensors.torch import save_file

    layer_types = get_hybrid_layer_types(model)
    hf_config = get_hybrid_hf_config(model, layer_types, max_seq_len=max_sequence_length)

    model_state_dict = {key: get_full_tensor(state) for key, state in model_state_dict.items()}
    hf_state = convert_hybrid_state_to_hf(model_state_dict, layer_types)

    if dtype is not None:
        hf_state = {
            k: v.to(dtype.as_pt()) if torch.is_tensor(v) else v for k, v in hf_state.items()
        }

    if vocab_size is not None:
        hf_config["vocab_size"] = vocab_size
        for name in ("model.embed_tokens.weight", "lm_head.weight"):
            if name in hf_state:
                hf_state[name] = hf_state[name][:vocab_size]

    # ``save_file`` expects contiguous tensors, and neither the conversion nor the dtype cast
    # is guaranteed to produce them. ``save_hf_model`` applies the same normalization before
    # it serializes. This is a no-op for tensors that are already contiguous.
    hf_state = {k: v.contiguous() if torch.is_tensor(v) else v for k, v in hf_state.items()}

    log.info(f"Converted state dict has {len(hf_state)} keys")

    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    config_path = save_path / "config.json"
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(hf_config, f, indent=2)
    log.info(f"Saved config to {config_path}")

    save_file(hf_state, save_path / "model.safetensors")
    log.info(f"Saved weights to {save_path / 'model.safetensors'}")