import dataclasses
import functools as ft
import logging
import math
import typing
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional
import torch
import torch.nn as nn
from safetensors.torch import save_file
from torch.distributed.tensor import DTensor
from olmo_core.distributed.checkpoint import save_state_dict
from olmo_core.distributed.utils import get_full_tensor, get_local_tensor, get_rank
from olmo_core.utils import gc_cuda
from ..common import MetricMergeStrategy, ReduceType
from ..train_module import TransformerTrainModule
from .callback import Callback
log = logging.getLogger(__name__)
[docs]
@dataclass
class GAPMonitorCallback(Callback):
"""
Gradient, activation, and parameter (GAP) monitoring callback.
This callback logs fine-grained statistics on all gradients, activations, and parameters.
It can also dump raw gradient tensors to disk for offline analysis. Set ``dump_gradients=True``
and configure the ``dump_gradients_*`` fields to control when and how gradients are saved.
"""
enabled: bool = True
"""Master switch. When ``False``, all monitoring and gradient dumping is disabled."""
monitor: Optional[bool] = None
"""Whether to run GAP monitoring (forward/backward hooks, per-tensor stats).
Only takes effect when ``enabled=True``. Defaults to ``True`` when ``enabled=True``."""
interval: int = 1
"""How often (in steps) to measure statistics. Default is every step."""
dump_gradients: Optional[bool] = None
"""Whether to dump raw gradient tensors to disk for offline analysis.
Only takes effect when ``enabled=True``. Defaults to ``False`` when ``enabled=True``."""
dump_gradients_start_step: int = 0
"""Step at which to begin dumping gradients. Inclusive."""
dump_gradients_end_step: Optional[int] = None
"""Step at which to stop dumping gradients. Inclusive. If ``None``, runs until training ends."""
dump_gradients_step_interval: int = 1
"""How often (in steps) to dump gradients. Must be positive."""
dump_gradients_save_first_n: Optional[int] = None
"""If set, gather the full gradient to rank 0 and save only the first N elements of each
dimension, storing as a single safetensors file. If ``None``, saves the full distributed
gradient via distributed checkpoint. Must be positive if set."""
_handles: Optional[list] = dataclasses.field(default=None, repr=False)
_local_batch_size_instances: int = dataclasses.field(default=1, repr=False)
_dry_run_complete: bool = dataclasses.field(default=False, repr=False)
def __post_init__(self):
# Validate: sub-flags must not be True when master switch is off.
if not self.enabled and (self.monitor or self.dump_gradients):
raise ValueError(
"Cannot set monitor or dump_gradients to True when enabled=False. "
"Set enabled=True to use this callback."
)
# Resolve None defaults.
if self.enabled:
self.monitor = self.monitor if self.monitor is not None else True
self.dump_gradients = self.dump_gradients if self.dump_gradients is not None else False
# Validate dump_gradients_* params only when gradient dumping is active.
if self.enabled and self.dump_gradients:
if self.dump_gradients_start_step < 0:
raise ValueError(
f"dump_gradients_start_step must be non-negative, got {self.dump_gradients_start_step}"
)
if self.dump_gradients_step_interval <= 0:
raise ValueError(
f"dump_gradients_step_interval must be positive, got {self.dump_gradients_step_interval}"
)
if (
self.dump_gradients_save_first_n is not None
and self.dump_gradients_save_first_n <= 0
):
raise ValueError(
f"dump_gradients_save_first_n must be positive, got {self.dump_gradients_save_first_n}"
)
# Validate: enabled but doing nothing.
if self.enabled and not self.monitor and not self.dump_gradients:
raise ValueError(
"enabled=True but both monitor and dump_gradients are False. "
"Set at least one to True, or set enabled=False."
)
def post_attach(self):
if not self.enabled:
return
if not isinstance(self.trainer.train_module, TransformerTrainModule):
raise ValueError(f"{type(self).__name__} only works with the TransformerTrainModule.")
def pre_train(self):
if not self.enabled or not self.monitor:
return
assert isinstance(self.trainer.train_module, TransformerTrainModule)
self._reset()
handles: List[torch.utils.hooks.RemovableHandle] = []
for n, m in self.trainer.train_module.model.named_modules():
m = typing.cast(nn.Module, m)
if n == "":
continue
# Register forward hook to monitor activations.
h = m.register_forward_hook(ft.partial(self.forward_hook, module_name=n))
handles.append(h)
# Register backward pre-hook to monitor gradients wrt activations.
h = m.register_full_backward_pre_hook(ft.partial(self.backward_hook, module_name=n))
handles.append(h)
self._handles = handles # type: ignore[assignment]
def pre_step(self, batch: Dict[str, Any]):
if not self.enabled or not self.monitor:
return
self._dry_run_complete = True
self._local_batch_size_instances = batch["input_ids"].shape[0]
def pre_optim_step(self):
if not self.enabled:
return
if self.monitor:
assert isinstance(self.trainer.train_module, TransformerTrainModule)
for n, p in self.trainer.train_module.model.named_parameters():
self.record_tensor_stats(n, p, "param")
if p.grad is not None:
self.record_tensor_stats(n, p.grad, "grad")
if self.dump_gradients:
self._dump_gradients()
@torch._dynamo.disable()
def forward_hook(self, module: nn.Module, args, output, module_name: str):
del module, args
if not self.enabled or not self.monitor:
return
if isinstance(output, tuple):
output = output[0]
if isinstance(output, torch.Tensor):
self.record_tensor_stats(module_name, output, "activation")
elif output is not None:
raise RuntimeError(f"unsupported output type {type(output)} for module '{module_name}'")
@torch._dynamo.disable()
def backward_hook(self, module: nn.Module, grad_output, module_name: str):
del module
if not self.enabled or not self.monitor:
return
if isinstance(grad_output, tuple):
grad_output = grad_output[0]
if isinstance(grad_output, torch.Tensor):
self.record_tensor_stats(module_name, grad_output, "activation_grad")
elif grad_output is not None:
raise RuntimeError(
f"unsupported grad_output type {type(grad_output)} for module '{module_name}'"
)
@torch.no_grad()
def record_tensor_stats(
self,
name: str,
tensor: torch.Tensor,
kind: Literal["grad", "activation", "activation_grad", "param"],
):
if self.step % self.interval != 0:
return
if tensor.numel() <= 1:
return
tensor = tensor.detach()
prefix = f"gap/{kind}s"
if kind in ("activation", "activation_grad"):
if tensor.ndim <= 1:
# No point in computing stats for 0-dim or 1-dim activations (like the loss).
return
# For activations/output-grads we'll compute the local stats *per instance* and then average them
# across the global batch.
# Technically it might be better to compute global stats directly, but this way is
# cheaper, much simpler, and probably good enough.
tensor = get_local_tensor(tensor)
tensor = tensor.reshape(tensor.shape[0], -1)
max_ = tensor.abs().max()
var, mean = var_mean(tensor, dim=-1)
# NOTE: to handle gradient accumulation we divide by local batch size (in instances),
# which is recorded in `self.pre_step()`, as opposed to micro-batch size, and then
# we use the "sum" merge strategy.
var = var.float().sum() / self._local_batch_size_instances
mean = mean.float().sum() / self._local_batch_size_instances
if self._dry_run_complete:
self.trainer.record_metric(
f"{prefix}/{name}/max",
max_,
reduce_type=ReduceType.max,
merge_strategy=MetricMergeStrategy.max,
)
self.trainer.record_metric(
f"{prefix}/{name}/mean",
mean,
reduce_type=ReduceType.mean,
merge_strategy=MetricMergeStrategy.sum,
)
self.trainer.record_metric(
f"{prefix}/{name}/var",
var,
reduce_type=ReduceType.mean,
merge_strategy=MetricMergeStrategy.sum,
)
else:
var, mean = var_mean(tensor)
local_tensor = get_local_tensor(tensor)
if local_tensor.numel() > 0:
local_max = local_tensor.abs().max()
else:
# Use 0.0 as sentinel value for empty tensors in max reduction.
# Since we're taking abs(), all actual values are >= 0, so 0.0
# won't affect the max reduction when other processes have non-empty tensors.
local_max = torch.zeros([], device=tensor.device, dtype=tensor.dtype)
if self._dry_run_complete:
self.trainer.record_metric(
f"{prefix}/{name}/max", local_max, reduce_type=ReduceType.max
)
self.trainer.record_metric(f"{prefix}/{name}/mean", mean, reduce_type=None)
self.trainer.record_metric(f"{prefix}/{name}/var", var, reduce_type=None)
def _dump_gradients(self):
"""Save gradient tensors to disk based on dump_gradients_* configuration."""
if self.step < self.dump_gradients_start_step:
return
if self.dump_gradients_end_step is not None and self.step > self.dump_gradients_end_step:
return
if (self.step - self.dump_gradients_start_step) % self.dump_gradients_step_interval != 0:
return
step_dir = self.trainer.work_dir / "gradients" / f"step{self.step}"
step_dir.mkdir(exist_ok=True, parents=True)
model = getattr(self.trainer.train_module, "model")
if self.dump_gradients_save_first_n is None:
# Save full gradients using distributed checkpoint
full_grads_dir = step_dir / "full_gradients"
grad_dict = {}
for name, p in model.named_parameters():
if p.grad is not None:
grad_dict[name] = p.grad.detach()
log.info(f"Saving {len(grad_dict)} gradient tensors for step {self.step}...")
save_state_dict(
full_grads_dir,
grad_dict,
save_overwrite=True,
)
log.info(f"Saved full gradients for step {self.step} to '{full_grads_dir}'")
else:
sampled_gradients_dir = step_dir / "sampled_gradients"
if get_rank() == 0:
sampled_gradients_dir.mkdir(exist_ok=True, parents=True)
for name, p in model.named_parameters():
if p.grad is not None:
full_grad = get_full_tensor(p.grad.detach())
if get_rank() == 0:
full_grad = full_grad.cpu()
if full_grad.ndim == 0:
sampled_grad = full_grad
else:
slices = []
for dim_size in full_grad.shape:
sampled_dim_size = min(self.dump_gradients_save_first_n, dim_size)
slices.append(slice(0, sampled_dim_size))
sampled_grad = full_grad[tuple(slices)].contiguous()
sampled_filepath = sampled_gradients_dir / f"{name}.safetensors"
save_file({"gradient": sampled_grad}, str(sampled_filepath))
log.info(
f"Saved sampled gradient '{name}' with shape {tuple(sampled_grad.shape)} "
f"to '{sampled_filepath}'"
)
del full_grad
if get_rank() == 0:
log.info(f"Saved sampled gradients for step {self.step} to {sampled_gradients_dir}")
if get_rank() == 0:
rel_step_dir = step_dir.relative_to(self.trainer.work_dir)
target_dir = self.trainer.persist_working_subdir(rel_step_dir)
log.info(f"Gradients for step {self.step} saved to '{target_dir}'")
gc_cuda()
def close(self):
self._reset()
def _reset(self):
self._dry_run_complete = False
if self._handles is not None:
for h in self._handles:
h.remove()
self._handles = None
def var_mean(tensor: torch.Tensor, dim: Optional[int] = None) -> tuple[torch.Tensor, torch.Tensor]:
if not isinstance(tensor, DTensor):
return torch.var_mean(tensor, dim=dim)
else:
# NOTE: 'torch.var_mean()' not implemented for DTensor.
numel = tensor.numel() if dim is None else tensor.size(dim)
mean = get_full_tensor(tensor.mean(dim=dim))
stdd = get_full_tensor(torch.linalg.vector_norm(tensor - mean, dim=dim)) / math.sqrt(
max(1, numel - 1)
)
var = stdd**2
return var, mean