# Source code for olmo_core.train.checkpoint

import json
import logging
import os
import re
import tempfile
from concurrent.futures import Future
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, ClassVar, Dict, Generator, Optional, Tuple, Union

import torch
import torch.distributed as dist
from cached_path import cached_path
from torch.distributed.checkpoint.metadata import Metadata

from ..aliases import PathOrStr
from ..config import Config
from ..distributed.checkpoint import (
    async_save_state_dict,
    get_checkpoint_metadata,
    load_state_dict,
    save_state_dict,
)
from ..distributed.utils import (
    barrier,
    broadcast_object,
    get_fs_local_rank,
    get_rank,
    is_distributed,
)
from ..exceptions import OLMoConfigurationError
from ..io import (
    clear_directory,
    dir_is_empty,
    file_exists,
    is_url,
    join_path,
    list_directory,
    normalize_path,
    upload,
)
from ..utils import wait_for
from ..version import VERSION
from .train_module import TrainModule

log = logging.getLogger(__name__)


@dataclass
class CheckpointerConfig(Config):
    """
    A configuration class for building :class:`Checkpointer` instances.

    Fields left as ``None`` are not forwarded to :class:`Checkpointer`, so the
    checkpointer's own defaults apply for them.
    """

    work_dir: Optional[str] = None
    save_overwrite: Optional[bool] = None
    pre_download: bool = False
    save_thread_count: Optional[int] = None
    load_thread_count: Optional[int] = None
    # save_process_count: Optional[int] = None
    throttle_uploads: bool = False

    def build(self, process_group: Optional[dist.ProcessGroup] = None, **kwargs) -> "Checkpointer":
        """
        Construct a :class:`Checkpointer` from this config.

        :param process_group: Optional process group handed to the checkpointer.
        :param kwargs: Extra/overriding keyword arguments; these take precedence over the
            non-``None`` fields of this config.

        :raises OLMoConfigurationError: If no ``work_dir`` is available from either the config
            or ``kwargs``.
        """
        # Start from the non-None config fields, then let explicit kwargs win.
        options = dict(self.as_dict(exclude_none=True, recurse=False))
        options.update(kwargs)

        raw_work_dir = options.pop("work_dir", None)
        if raw_work_dir is not None:
            return Checkpointer(
                work_dir=Path(raw_work_dir), process_group=process_group, **options
            )
        raise OLMoConfigurationError("'work_dir' must be provided to build a Checkpointer")
@dataclass
class CheckpointMetadata(Config):
    """
    Small metadata record serialized next to a saved checkpoint.
    """

    # Whether the checkpoint was saved as ephemeral. ``None`` means "not recorded";
    # readers treat that as non-ephemeral for backwards compatibility.
    ephemeral: Optional[bool] = None
    # Version string of the code that produced this record (defaults to ``VERSION``).
    version: str = VERSION
@dataclass
class Checkpointer:
    """
    Trainer checkpointer.

    Saves/loads trainer state (one ``train/rank{N}.pt`` file per rank), model and optimizer
    state (a distributed checkpoint under ``model_and_optim/``), and a small metadata JSON file
    (:data:`METADATA_FNAME`) to/from local directories or remote URLs.
    """

    METADATA_FNAME: ClassVar[str] = ".metadata.json"
    CHECKPOINT_DIR: ClassVar[str] = "step{step}"
    FS_TIMEOUT: ClassVar[float] = 120.0

    work_dir: Path
    save_overwrite: bool = False
    pre_download: bool = False
    process_group: Optional[dist.ProcessGroup] = None
    save_thread_count: Optional[int] = None
    load_thread_count: Optional[int] = None
    # save_process_count: Optional[int] = None  # TODO: leads to some MP issues, needs more investigating.
    throttle_uploads: bool = False

    def __post_init__(self):
        self.work_dir = Path(self.work_dir)
        # Only one rank per file system creates the local working directory.
        if get_fs_local_rank() == 0:
            self.work_dir.mkdir(exist_ok=True, parents=True)

    def save(
        self,
        dir: PathOrStr,
        train_module: TrainModule,
        train_state: Dict[str, Any],
        ephemeral: bool = False,
    ):
        """
        Save model, optim, and other training state to a local or remote directory.

        Everything except the metadata file is first written to a temporary working directory,
        which is then moved (local) or uploaded (remote) to ``dir``. The metadata file is written
        last, after everything else is in place.

        :param dir: The local path or URL of the checkpoint directory.
        :param train_module: The train module whose ``state_dict_to_save()`` is saved.
        :param train_state: Per-rank trainer state, saved with :func:`torch.save`.
        :param ephemeral: Recorded in the checkpoint's metadata file.

        :raises FileExistsError: If ``dir`` is not empty and ``save_overwrite`` is not set.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        dir = normalize_path(dir)
        with self._temporary_wd(dir) as wd:
            # Save trainer state.
            self._save_train_state(dir, wd, train_state)

            # Save model and optim state.
            train_module_dir = f"{dir}/model_and_optim" if is_url(dir) else wd / "model_and_optim"
            save_state_dict(
                train_module_dir,
                train_module.state_dict_to_save(),
                process_group=self.process_group,
                thread_count=self.save_thread_count,
                # process_count=self.save_process_count,
                throttle_uploads=self.throttle_uploads,
                enable_plan_caching=True,
                # NOTE: we've already checked and cleared the directory at this point so we can skip
                # the extra synchronization.
                _skip_prepare=True,
            )

        # Written only after the temporary directory has been moved/uploaded into place.
        self._save_metadata(dir, CheckpointMetadata(ephemeral=ephemeral))

    def save_async(
        self,
        dir: PathOrStr,
        train_module: TrainModule,
        train_state: Dict[str, Any],
        ephemeral: bool = False,
    ) -> Future[None]:
        """
        An async version of :meth:`save()`.

        Trainer state is written synchronously. Model and optim state are written in the
        background, and the metadata file is written only once that background save has
        completed successfully.

        :returns: The future of the background model/optim save.

        :raises OLMoConfigurationError: If running distributed without a checkpointer
            ``process_group``.
        """
        if is_distributed() and self.process_group is None:
            raise OLMoConfigurationError(
                "a checkpointer process group is required for async checkpointing!"
            )

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        dir = normalize_path(dir)
        with self._temporary_wd(dir) as wd:
            # Save trainer state.
            self._save_train_state(dir, wd, train_state)

        # Save model and optim state.
        train_module_dir = f"{dir}/model_and_optim"
        future = async_save_state_dict(
            train_module_dir,
            train_module.state_dict_to_save(),
            process_group=self.process_group,
            thread_count=self.save_thread_count,
            # process_count=self.save_process_count,
            throttle_uploads=self.throttle_uploads,
            enable_plan_caching=True,
            # NOTE: we've already checked and cleared the directory at this point so we can skip
            # the extra synchronization.
            _skip_prepare=True,
        )

        def done_callback(fut: Future):
            # The metadata file marks the checkpoint as complete, so don't write it when the
            # background save was cancelled or failed. The caller still sees the failure through
            # the returned future.
            if fut.cancelled():
                log.warning(f"Async checkpoint save to '{dir}' was cancelled; not writing metadata")
                return
            exc = fut.exception()
            if exc is not None:
                log.warning(
                    f"Async checkpoint save to '{dir}' failed ({exc!r}); not writing metadata"
                )
                return
            self._save_metadata(dir, CheckpointMetadata(ephemeral=ephemeral))

        # Upload metadata when everything else is done.
        future.add_done_callback(done_callback)
        return future

    def load(
        self,
        dir: PathOrStr,
        train_module: TrainModule,
        *,
        load_trainer_state: Optional[bool] = None,
        load_optim_state: Optional[bool] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Load model, optim, and other training state from a local or remote checkpoint directory
        created via :meth:`save()` or :meth:`save_async()`.

        If ``dir/model_and_optim`` has no checkpoint metadata and no trainer state was loaded,
        ``dir`` itself is treated as the model (and maybe optim) checkpoint.

        :param dir: The local path or URL of the checkpoint directory.
        :param train_module: Receives the loaded state via ``load_state_dict()``.
        :param load_trainer_state: ``False`` skips trainer state, ``True`` requires it, and
            ``None`` loads it if present.
        :param load_optim_state: Forwarded as ``optim=`` to ``train_module.state_dict_to_load()``.

        :returns: The trainer state, or ``None`` if it was skipped or not found.

        :raises FileNotFoundError: If ``load_trainer_state`` is ``True`` and no trainer state exists.
        """
        dir = normalize_path(dir)

        # Maybe load trainer state.
        trainer_state: Optional[Dict[str, Any]] = None
        if load_trainer_state is not False:
            # Try loading the given rank's state first, then fall back to rank 0 train state if it
            # doesn't exist, which can happen when we're restoring a checkpoint with a different world size.
            for path in (f"{dir}/train/rank{get_rank()}.pt", f"{dir}/train/rank0.pt"):
                try:
                    trainer_state = torch.load(cached_path(path, quiet=True), weights_only=False)
                    break
                except FileNotFoundError:
                    pass

            if load_trainer_state is True and trainer_state is None:
                raise FileNotFoundError(f"Missing trainer state in checkpoint dir '{dir}'")

        # Load train module state.
        train_module_dir = f"{dir}/model_and_optim"
        metadata: Optional[Metadata] = None
        if get_rank(self.process_group) == 0:
            try:
                metadata = get_checkpoint_metadata(train_module_dir)
            except FileNotFoundError:
                # Try base directory, which could be the case if user is trying to load model weights
                # (possibly with optimizer state), and not an actual train checkpoint.
                if trainer_state is None:
                    metadata = get_checkpoint_metadata(dir)
                    train_module_dir = dir
                else:
                    raise

        # Every rank uses the directory that rank 0 resolved.
        train_module_dir = broadcast_object(train_module_dir)
        if metadata is None:
            metadata = get_checkpoint_metadata(train_module_dir)

        state_dict = train_module.state_dict_to_load(metadata, optim=load_optim_state)
        load_state_dict(
            train_module_dir,
            state_dict,
            process_group=self.process_group,
            pre_download=is_url(dir) and self.pre_download,
            work_dir=self.work_dir,
            thread_count=self.load_thread_count,
        )
        train_module.load_state_dict(state_dict)

        return trainer_state

    def write_file(self, dir: PathOrStr, fname: str, contents: Union[str, bytes]) -> PathOrStr:
        """
        Write something to a file in a local or remote directory.

        The contents are first written to a temporary file. That file is then uploaded (remote)
        or atomically renamed into place (local). Text contents are encoded as UTF-8.

        :param dir: The path/URL of the directory to write the file to.
        :param fname: The name of the file to write, relative to ``dir``.
        :param contents: The contents of the file to write.

        :returns: The path/URL of the file.

        :raises FileExistsError: If the local target file exists and ``save_overwrite`` is not set.
        """
        dir = normalize_path(dir)
        fname = normalize_path(fname)

        if not is_url(dir):
            Path(dir).mkdir(exist_ok=True, parents=True)

        binary = isinstance(contents, bytes)
        # For local targets, the temp file lives in the target dir so the final rename stays on
        # the same file system.
        tmp_file = tempfile.NamedTemporaryFile(
            mode="wb" if binary else "wt",
            encoding=None if binary else "utf-8",
            delete=False,
            dir=None if is_url(dir) else dir,
        )
        tmp_path = Path(tmp_file.name)
        try:
            tmp_file.write(contents)
            # Ensure all data is written to disk.
            tmp_file.flush()
            if hasattr(os, "fdatasync"):  # only available on linux
                os.fdatasync(tmp_file.fileno())
            tmp_file.close()

            target: PathOrStr
            if is_url(dir):
                target = f"{dir}/{fname}"
                upload(tmp_path, target, save_overwrite=self.save_overwrite)
            else:
                target = Path(dir) / fname
                if target.is_file() and not self.save_overwrite:
                    raise FileExistsError(target)
                target.parent.mkdir(exist_ok=True, parents=True)
                tmp_path.replace(target)

            return target
        finally:
            # Closing twice is harmless; the temp file is gone already if it was renamed.
            tmp_file.close()
            tmp_path.unlink(missing_ok=True)

    @classmethod
    def checkpoint_dirname(cls, step: int) -> str:
        """Return the checkpoint directory name for ``step`` (from :data:`CHECKPOINT_DIR`)."""
        return cls.CHECKPOINT_DIR.format(step=step)

    @classmethod
    def dir_is_checkpoint(cls, dir: PathOrStr) -> bool:
        """
        Check if a directory is a checkpoint directory.

        This is true for either a bare model (and maybe optim) checkpoint with a ``.metadata`` file
        directly in ``dir``, or a full train checkpoint with rank 0 trainer state, model/optim
        metadata, and the checkpointer's metadata file.
        """
        dir = normalize_path(dir)

        if file_exists(f"{dir}/.metadata"):
            # just model (and maybe optim state), no trainer state
            return True

        paths_to_check = [
            f"{dir}/train/rank0.pt",
            f"{dir}/model_and_optim/.metadata",
            f"{dir}/{cls.METADATA_FNAME}",
        ]
        return all(file_exists(path) for path in paths_to_check)

    @classmethod
    def find_checkpoints(
        cls, dir: PathOrStr, ephemeral: Optional[bool] = None
    ) -> Generator[Tuple[int, str], None, None]:
        """
        Find checkpoints within a directory.

        :param dir: The directory to search (non-recursively).
        :param ephemeral: If given, only yield checkpoints whose metadata ``ephemeral`` flag
            matches. A missing flag is treated as ``False``.

        :returns: A generator of ``(step, path)`` pairs for child directories whose names match
            :data:`CHECKPOINT_DIR` and which are valid checkpoints.
        """
        dir = normalize_path(dir)
        # Escape the literal parts of the template and capture the step digits, so the whole
        # name must match (including any text after '{step}').
        pattern = re.compile(re.escape(cls.CHECKPOINT_DIR).replace(re.escape("{step}"), r"(\d+)"))
        for path in list_directory(dir):
            name = os.path.basename(path)
            if (m := pattern.fullmatch(name)) is None:
                continue
            step = int(m.group(1))

            # Make sure the directory is a valid checkpoint dir.
            if not cls.dir_is_checkpoint(path):
                continue

            # Filter out based on ephemeral flag.
            if ephemeral is not None:
                metadata_path = cached_path(join_path(path, cls.METADATA_FNAME), quiet=True)
                metadata = CheckpointMetadata.from_file(metadata_path)
                # Assume not ephemeral for backwards compat.
                if metadata.ephemeral is None:
                    metadata.ephemeral = False
                if metadata.ephemeral != ephemeral:
                    continue

            yield step, path

    @classmethod
    def contains_checkpoint(cls, dir: PathOrStr) -> bool:
        """
        Check if a directory is a checkpoint directory or contains a child checkpoint directory.
        """
        if cls.dir_is_checkpoint(dir):
            return True

        try:
            next(cls.find_checkpoints(dir))
            return True
        except (StopIteration, FileNotFoundError):
            return False

    @classmethod
    def latest_checkpoint(cls, dir: PathOrStr) -> str:
        """
        Find the latest checkpoint in a directory of checkpoints.

        :returns: The path of the checkpoint with the highest step (the first one found on ties).

        :raises FileNotFoundError: If no checkpoints are found.
        """
        dir = normalize_path(dir)
        latest = max(cls.find_checkpoints(dir), key=lambda step_and_path: step_and_path[0], default=None)
        if latest is None:
            raise FileNotFoundError(f"No checkpoints found in '{dir}'")
        return latest[1]

    def _save_train_state(self, dir: PathOrStr, wd: Path, train_state: Dict[str, Any]):
        """Save this rank's trainer state to ``wd/train/rank{rank}.pt``."""
        train_dir = wd / "train"
        # NOTE: if 'dir' is a URL, the 'wd' will be a different temp dir for each rank.
        if is_url(dir) or get_fs_local_rank() == 0:
            train_dir.mkdir(exist_ok=True, parents=True)
        wait_for(
            train_dir.exists,
            description=f"waiting for '{train_dir}' to be created...",
            timeout=self.FS_TIMEOUT,
        )
        torch.save(train_state, train_dir / f"rank{get_rank()}.pt")

    def _save_metadata(self, dir: PathOrStr, metadata: CheckpointMetadata):
        """Write the checkpoint metadata JSON file (global rank 0 only)."""
        if get_rank() == 0:
            self.write_file(dir, self.METADATA_FNAME, json.dumps(metadata.as_dict(json_safe=True)))

    def _prepare_dir(self, dir: PathOrStr, ensure_exists: bool = True) -> str:
        """
        Make sure ``dir`` is empty (clearing it if ``save_overwrite``), synchronize all ranks,
        and optionally create it (local paths only).

        :returns: The normalized ``dir``.

        :raises FileExistsError: If ``dir`` is not empty and ``save_overwrite`` is not set.
        """
        dir = normalize_path(dir)

        # Make sure checkpoint directory is empty.
        if self.save_overwrite:
            if get_fs_local_rank() == 0:
                clear_directory(dir)
        elif not dir_is_empty(dir):
            raise FileExistsError(dir)

        # NOTE: We need a barrier here in both cases.
        # 1. If 'self.save_overwrite' then we clear the directory, and anytime we clear a directory in
        #    preparation to use it we should have a barrier right after, otherwise one rank might get
        #    ahead and write something to the directory prematurely, which then gets removed by the call
        #    to `clear_directory()`.
        # 2. And otherwise we are checking if the directory is empty and raising an error if it's not,
        #    so we need to make sure all ranks are synchronized on that check before they can proceed
        #    to write to the directory.
        barrier()

        if ensure_exists and not is_url(dir):
            if get_fs_local_rank() == 0:
                Path(dir).mkdir(exist_ok=True, parents=True)
            # Ensure the dir exists for all ranks before continuing. This might take a second if we're
            # saving to an NFS drive or something like that.
            wait_for(
                Path(dir).exists,
                description=f"waiting on '{dir}' to be created...",
                timeout=self.FS_TIMEOUT,
            )

        return dir

    def _get_tmp_dir(self, dir: PathOrStr) -> Path:
        """
        Create the temporary working directory for a save to ``dir``.

        For URLs, each rank gets its own fresh temp dir under ``work_dir``. For local paths, a
        sibling ``<dir>-tmp`` directory shared by ranks is cleared and (re)created.
        """
        # Prepare temporary directory.
        tmp_dir: Path
        if is_url(dir):
            tmp_dir = Path(tempfile.mkdtemp(dir=str(self.work_dir)))
        else:
            tmp_dir = Path(dir).with_name(Path(dir).name + "-tmp")
            if get_fs_local_rank() == 0:
                clear_directory(tmp_dir)
                tmp_dir.mkdir(exist_ok=True, parents=True)

        # NOTE: anytime we clear a directory in preparation to use it we should have a barrier
        # right after, otherwise one rank might get ahead and write something to the directory
        # prematurely, which then gets removed by the call to `clear_directory()`.
        barrier()

        # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
        # creating the temp directory from rank 0 might not be immediately
        # realized in the file systems of the other ranks.
        # So we wait here across all ranks until that tmp checkpoint directory is visible.
        wait_for(
            lambda: tmp_dir.exists(),
            "Waiting for checkpoint directory",
            timeout=self.FS_TIMEOUT,
        )
        return tmp_dir

    def _teardown_tmp_dir(self, dir: PathOrStr, tmp_dir: Path):
        """
        Move (local) or upload (URL) the contents of ``tmp_dir`` to ``dir``.

        For URLs, the local temp dir is cleared afterward.
        """
        if not is_url(dir):
            # NOTE: When dir is not a URL, tmp dir is shared among ranks so we need a barrier before
            # we tear it down to avoid overwriting the work of other ranks.
            barrier()

            # Replace the temporary directory with the actual checkpoint directory.
            if get_fs_local_rank() == 0:
                # Replace temp directory with target checkpoint directory.
                try:
                    tmp_dir.replace(str(dir))
                except FileNotFoundError:
                    # Caught when another (file-system) local rank 0 has already replaced the tmp directory.
                    # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
                    # file-systems.
                    if not Path(dir).exists():
                        raise

            # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
            # replacing the temp directory with the final directory from rank 0 might not be immediately
            # realized in the file systems of the other ranks.
            # So we wait here across all ranks until that final checkpoint directory is visible.
            wait_for(
                lambda: Path(dir).exists(),
                f"waiting for checkpoint directory '{dir}' from rank {get_rank()}",
                timeout=self.FS_TIMEOUT,
            )
        else:
            # NOTE: When dir is a URL, each rank will have its own tmp dir so synchronizing with a
            # barrier isn't necessary.
            # Upload files to final location.
            for path in tmp_dir.glob("**/*"):
                if not path.is_file():
                    continue
                # URL keys always use '/' separators, regardless of the local OS.
                upload(
                    path,
                    f"{dir}/{path.relative_to(tmp_dir).as_posix()}",
                    save_overwrite=self.save_overwrite,
                )

            # Then remove the temp dir.
            clear_directory(tmp_dir)

    @contextmanager
    def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
        """
        Prepare ``dir``, yield a temporary working directory, then move/upload its contents
        to ``dir`` when the ``with`` body finishes without raising.
        """
        # No need to mkdir here since we'll directly replace the temporary directory with
        # this directory below.
        dir = self._prepare_dir(dir, ensure_exists=False)

        tmp_dir = self._get_tmp_dir(dir)

        yield tmp_dir

        self._teardown_tmp_dir(dir, tmp_dir)