# Copyright (c) Meta Platforms, Inc. and affiliates.

# This file from the xFormers repo is just an example of how to implement
# probing of the activations of a model, without changing anything.
# By default, the linear inputs/outputs/gradients are logged, as well as
# the attention logits+entropy. It is possible to log an additional tensor, e.g.:
# x = log_stats(x, "name")
#
# Known limitations:
# * Only a subset of the attention biases is supported
# * Torch-compile is disabled automatically when this is enabled
# * Only tested with bf16/f16/f32 datatypes


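# Typical usage (illustrative sketch; see `test_toy_model` at the bottom of this
# file; `model`, `batch`, and `step` are placeholders):
#
#   probe = AutoProbeD(model, "./probe.json")
#   with probe:
#       probe.metadata = {"it": step}
#       loss = model(batch).sum()
#       loss.backward()
#
# Every `with probe:` block appends one JSON record of statistics to the file.

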
import contextlib
import functools
import json
import math
import os
import uuid
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)
from torch.fx.operator_schemas import normalize_function
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torch.utils.module_tracker import ModuleTracker
from xformers.ops import fmha


@torch.library.custom_op("torchprobe::log", mutates_args=(), device_types=None)
def _log(x: torch.Tensor, name: str, uid: str) -> None:
    pass


@_log.register_fake
def _log_fake(x: torch.Tensor, name: str, uid: str) -> None:
    pass


class _LogStats(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, name: str):
        uid = str(uuid.uuid4())
        torch.ops.torchprobe.log(x, name, uid)
        ctx.name = name
        ctx.uid = uid
        return x

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        torch.ops.torchprobe.log(grad, f"{ctx.name}.g", ctx.uid)
        return grad, None


_PROBING_ENABLED = False


def log_stats(x: torch.Tensor, name: str) -> torch.Tensor:
    if not _PROBING_ENABLED:
        return x
    return _LogStats.apply(x, name)


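# `log_stats` is an identity op unless an `AutoProbeD` context is active: it routes
# the tensor through the `torchprobe::log` custom op so that `__torch_dispatch__`
# can intercept it. Illustrative example (`h` stands for any activation inside a
# probed module):
#
#   h = log_stats(h, "ffn_out")
#
# This records stats under "<module path>::ffn_out" in the forward pass and under
# "<module path>::ffn_out.g" for its gradient in the backward pass.

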
QUANTILES = [
    0.0000001,
    0.000001,
    0.00001,
    0.0001,
    0.001,
    0.01,
    0.05,
    0.1,
    0.3,
    0.5,
    0.7,
    0.9,
    0.95,
    0.99,
    0.999,
    0.9999,
    0.99999,
    0.999999,
    0.9999999,
]


@functools.cache
def _get_quantiles(device: torch.device, dtype) -> torch.Tensor:
    return torch.tensor(QUANTILES, device=device, dtype=dtype)


def _get_stats(x_: torch.Tensor, remove_inf=False) -> Dict[str, Any]:
    if x_.dtype not in [torch.float, torch.double, torch.float16, torch.bfloat16]:
        return {}
    x = x_.flatten()
    if remove_inf:
        x = x[x.abs() < float("inf")]
    if x.dtype is not torch.double:
        x = x.float()
    xabs = x.abs()
    quantiles = _get_quantiles(x.device, x.dtype)
    mean = x.mean()
    std = x.std()
    return {
        "shape": tuple(x_.shape),
        "mean": mean,
        "std": std,
        "skew": (((x - mean) / std) ** 3).double().mean(),
        "kurtosis": (((x - mean) / std) ** 4).double().mean(),
        "abs.mean": xabs.mean(),
        "max": x.max(),
        "min": x.min(),
        # Note: `quantile` takes at most 2**24 elements, see
        # https://github.com/pytorch/pytorch/issues/64947
        "quantiles": torch.quantile(x[: 2**24], quantiles),
    }


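# Note: `kurtosis` above is the raw fourth standardized moment (not the excess
# kurtosis), so a roughly Gaussian tensor should report skew close to 0 and
# kurtosis close to 3.

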
def _mask_attn_causal_inplace(logits: torch.Tensor, q_idx, q_len, kv_len) -> None:
    assert logits.ndim == 4
    logits[:, :, :, q_idx + kv_len - q_len + 1 :] = -math.inf


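# Worked example (illustrative): with q_len=4 and kv_len=6, query q_idx=0 keeps
# keys [0, 1, 2]; columns from 0 + 6 - 4 + 1 = 3 onwards are set to -inf, i.e. the
# mask is causal with bottom-right alignment, matching the bottom-right-aligned
# biases exercised in `test_masking` below.

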
def _mask_attn_logits(
    logits: torch.Tensor,
    q_idx: List[int],
    *,
    causal: bool,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert logits.dtype is torch.float32
    # Handle BlockDiagonalMask
    if cu_seqlens_q is not None:
        assert cu_seqlens_k is not None
        # Expect BHMqMkv
        assert logits.ndim == 4, logits.shape
        qs = cu_seqlens_q.tolist()
        ks = cu_seqlens_k.tolist()
        q_batchid = []
        k_batchid = [-2] * logits.shape[-1]
        q_idx_i = 0
        for bid, (q0, q1, k0, k1) in enumerate(zip(qs, qs[1:], ks, ks[1:])):
            for k in range(k0, k1):
                k_batchid[k] = bid
            while q_idx_i < len(q_idx) and q_idx[q_idx_i] < q1:
                q_batchid.append(bid)
                if causal:
                    _mask_attn_causal_inplace(
                        logits[:, :, q_idx_i : q_idx_i + 1, k0:k1],
                        q_idx[q_idx_i] - q0,
                        q1 - q0,
                        k1 - k0,
                    )
                q_idx_i += 1
        mask_out = (
            torch.tensor(q_batchid, device=logits.device)[None, None, :, None]
            != torch.tensor(k_batchid, device=logits.device)[None, None, None, :]
        )
        logits[mask_out.expand_as(logits)] = -math.inf
        assert q_idx_i == len(q_idx)
    elif causal:
        for q_idx_i in range(len(q_idx)):
            _mask_attn_causal_inplace(
                logits[:, :, q_idx_i : q_idx_i + 1, :],
                q_idx[q_idx_i],
                logits.shape[2],
                logits.shape[3],
            )
    return logits


def _attn_queries_subset(num_queries: int) -> List[int]:
    return list(range(0, num_queries, max(1, num_queries // 128)))


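# Example (illustrative): for num_queries=1024 the stride is max(1, 1024 // 128) = 8,
# so queries 0, 8, ..., 1016 are probed (128 of them); below 256 queries the stride
# is 1 and every query is kept.

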
@torch.no_grad()
def _compute_attn_stats_sdpa(
    probe,
    path: str,
    # supports arguments for both the cudnn and flash backends
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    compute_log_sumexp=True,
    return_debug_mask=False,
    **kwargs,
):
    if scale is None:
        scale = 1 / (query.shape[-1] ** 0.5)
    # Filter out unsupported cases
    if attn_mask is not None or attn_bias is not None or dropout_p != 0.0 or kwargs:
        probe.store[f"{path}::attn"] = {
            "query.shape": tuple(query.shape),
            "key.shape": tuple(key.shape),
            "value.shape": tuple(value.shape),
            "attn_mask": attn_mask.shape if attn_mask is not None else None,
            "dropout_p": dropout_p,
            "is_causal": is_causal,
            "scale": scale,
            "unk_kwargs": list(kwargs.keys()),
        }
        return
    # Take a subset of the queries and compute the logits
    query_s = _attn_queries_subset(query.shape[-2])
    logits = query[:, :, query_s] @ key.transpose(-1, -2) * scale
    logits = _mask_attn_logits(logits.float(), query_s, causal=is_causal)
    p = logits.float().softmax(-1)
    masked_logsoft = logits.log_softmax(-1).where(
        (logits > -math.inf), torch.zeros_like(logits)
    )
    entropy = -(p * masked_logsoft).sum(-1)
    probe.log_tensor(f"{path}::attn_entropy", entropy)
    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)


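# Both the SDPA path above and the flash path below compute, for each probed query,
# the softmax distribution p over keys and its entropy H = -sum_k p_k * log(p_k)
# (fully masked -inf positions contribute zero). The entropy is logged as
# "<path>::attn_entropy" and the raw logits as "<path>::attn_logits".

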
@torch.no_grad()
def _compute_attn_stats_flash(
    probe,
    path: str,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    seqused_k: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int,
    p: float,
    softmax_scale: float,
    is_causal: bool,
    window_left: int,
    window_right: int,
    return_softmax: bool,
    block_tables: Optional[torch.Tensor],
    unpadded_lse: bool = False,
) -> None:
    # Filter out unsupported cases
    if (
        seqused_k is not None
        or p != 0.0
        or window_left >= 0
        or window_right >= 0
        or block_tables is not None
    ):
        probe.store[f"{path}::attn"] = {
            "query.shape": tuple(query.shape),
            "key.shape": tuple(key.shape),
            "value.shape": tuple(value.shape),
            "op": "flash",
        }
        return

    if cu_seqlens_q is not None:
        assert query.ndim == 3, query.shape
        query, key, value = query[None], key[None], value[None]
    assert query.ndim == 4, query.shape

    # Take a subset of the queries and compute the logits
    query_s = _attn_queries_subset(query.shape[1])
    logits = (
        query[:, query_s].transpose(1, 2)
        @ key.transpose(1, 2).transpose(-1, -2)
        * softmax_scale
    )
    logits = _mask_attn_logits(
        logits.float(),
        query_s,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        causal=is_causal,
    )
    p = logits.float().softmax(-1)
    masked_logsoft = logits.log_softmax(-1).where(
        (logits > -math.inf), torch.zeros_like(logits)
    )
    entropy = -(p * masked_logsoft).sum(-1)
    probe.log_tensor(f"{path}::attn_entropy", entropy)
    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)


def _tensors_to_python(x):
    if not isinstance(x, torch.Tensor):
        return x
    return x.tolist()


# Type of GEMM encountered in the backward pass of a linear layer
class LinearBwType(Enum):
    DW = 1
    DX = 2
    UNKNOWN = 3


class AutoProbeD(TorchDispatchMode):
    def __init__(self, module: nn.Module, write_file: Optional[str] = None) -> None:
        self.write_file = Path(write_file) if write_file is not None else None
        self.write_tensors_tmpdir: Optional[Path] = None
        self.compile_disabler = TorchCompileDisabler(module)
        self.mod_tracker = ModuleTracker()
        self.count_per_path: Dict[str, int] = defaultdict(int)
        self.store: Dict[str, Dict[str, Any]] = {}
        self.linear_data: Dict[str, Tuple[Any, Any, Any, Any, Any]] = {}
        self.uid_to_path: Dict[str, str] = {}
        self.metadata: Any = None
        self.enabled = False
        self.verbose = bool(int(os.environ.get("PROBE_VERBOSE", "0")))

    def __enter__(self):
        global _PROBING_ENABLED
        assert not self.enabled, "Entered probe twice"
        self.compile_disabler.__enter__()
        self.mod_tracker.__enter__()
        super().__enter__()
        self.enabled = True
        _PROBING_ENABLED = True
        # self._setup_tensors_logging()
        return self

    def __exit__(self, *args) -> None:
        global _PROBING_ENABLED
        assert self.enabled, "Exiting probe without entering it"
        super().__exit__(*args)
        self.mod_tracker.__exit__(*args)
        self.compile_disabler.__exit__(*args)
        self._flush_and_clear()
        _PROBING_ENABLED = False
        self.enabled = False

    def _setup_tensors_logging(self):
        if self.write_file is not None:
            self.write_file.parent.mkdir(exist_ok=True)
            self.write_tensors_tmpdir = (
                self.write_file.parent
                / f"{self.write_file.name}-tmp-{str(uuid.uuid4())[:8]}"
            )
            self.write_tensors_tmpdir.mkdir(exist_ok=True)

    def _flush_and_clear(self) -> None:
        if self.write_file is not None:
            dump_data = tree_map(_tensors_to_python, self.store)
            with self.write_file.open("a") as fd:
                json.dump(
                    {
                        "data": dump_data,
                        "meta": self.metadata,
                        "version": 2,
                        "quantiles": QUANTILES,
                    },
                    fd,
                )
                fd.write("\n")
        if self.write_tensors_tmpdir is not None:
            assert self.write_file is not None
            dump_dir = self.write_tensors_tmpdir.parent / f"{self.write_file.name}-dump"
            dump_dir.mkdir(exist_ok=True)
            dir_name = ""
            if "it" in self.metadata:
                dir_name = f"it{int(self.metadata['it']):010}"
            if dir_name == "" or (dump_dir / dir_name).exists():
                num_files = len(list(dump_dir.glob(f"{dir_name}v*")))
                dir_name = f"{dir_name}v{num_files}"
            dump_dir = dump_dir / dir_name
            assert not dump_dir.exists()
            self.write_tensors_tmpdir.rename(dump_dir)
            self.write_tensors_tmpdir = None
        self.store.clear()
        self.count_per_path.clear()
        self.uid_to_path.clear()

    def _find_bw_path_and_type(
        self, path: str, out: torch.Tensor, args
    ) -> Tuple[str, LinearBwType]:
        """
        We are in the BW pass, and process a GEMM.
        Let's figure out:
        (1) The path for the FW pass (might differ in case of ModuleTracker bug)
        (2) The type of BW pass (e.g. `dw` or `dx`)
        """

        def _is_path_correct_dw(path: str) -> bool:
            # dW.t = dY.t @ X
            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
            return out.shape == (w_shape[1], w_shape[0]) and torch.allclose(
                input_sm, args[1][:4, :4]
            )

        def _is_path_correct_dx(path: str) -> bool:
            # dX = dY @ W.t
            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
            return out.shape == in_shape and torch.allclose(weight_sm, args[1][:4, :4])

        if path in self.linear_data:
            if _is_path_correct_dw(path):
                return path, LinearBwType.DW
            if _is_path_correct_dx(path):
                return path, LinearBwType.DX
        for candidate_path in self.mod_tracker.parents:
            if candidate_path not in self.linear_data:
                continue
            if _is_path_correct_dw(candidate_path):
                return candidate_path, LinearBwType.DW
            if _is_path_correct_dx(candidate_path):
                return candidate_path, LinearBwType.DX
        return path, LinearBwType.UNKNOWN

    def log_tensor(self, name: str, x: torch.Tensor, **kwargs) -> None:
        self.store[name] = _get_stats(x, **kwargs)
        if self.write_tensors_tmpdir is not None:
            name_safe = name.replace("::", "__").replace("/", "")
            torch.save(x, self.write_tensors_tmpdir / f"{name_safe}.pkl")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        path = None
        # Find longest path
        for p in self.mod_tracker.parents:
            if p == "Global":
                continue
            if path is None or len(p) > len(path):
                path = p
        if path is None:
            path = "Global"
        path = path.replace("._checkpoint_wrapped_module", "")
        out = func(*args, **kwargs)

        # Handle linear layers
        if func._overloadpacket in [torch.ops.aten.addmm, torch.ops.aten.mm]:
            weight: torch.Tensor
            input: torch.Tensor
            if not self.mod_tracker.is_bw:
                # (technically, weight is transposed)
                if func._overloadpacket == torch.ops.aten.addmm:
                    _bias, input, weight = args[:3]
                else:
                    assert func._overloadpacket == torch.ops.aten.mm
                    input, weight = args[:2]
                self.log_tensor(f"{path}::in", input)
                self.log_tensor(f"{path}::w", weight)
                self.log_tensor(f"{path}::out", out)
                self.linear_data[path] = (
                    input.shape,
                    weight.shape,
                    out.shape,
                    input[:4, :4].clone(),
                    weight[:4, :4].T.clone(),
                )
            elif func._overloadpacket == torch.ops.aten.mm:
                # XXX: Try to find the actual path for the linear layer
                # This sometimes gets messed up with Francisco's FSDP
                new_path, bwtype = self._find_bw_path_and_type(path, out, args)
                if new_path != path:
                    if self.verbose:
                        print(f"E: Fixing path `{path}` -> `{new_path}`")
                    path = new_path

                if bwtype == LinearBwType.DW:
                    # dW.t = dY.t @ X
                    self.log_tensor(f"{path}::w.g", out)
                elif bwtype == LinearBwType.DX:
                    # dX = dY @ W.t
                    self.log_tensor(f"{path}::in.g", out)
                    self.log_tensor(f"{path}::out.g", args[0])
        elif func._overloadpacket in [
            torch.ops.aten._scaled_dot_product_flash_attention,
            torch.ops.aten._scaled_dot_product_cudnn_attention,
        ]:
            _, kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            _compute_attn_stats_sdpa(self, path, **kwargs)
        elif func._overloadpacket == fmha.flash.FwOp.OPERATOR:
            _, kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            _compute_attn_stats_flash(self, path, **kwargs)
        elif func._overloadpacket == torch.ops.torchprobe.log:
            uid = args[2]
            path = self.uid_to_path.setdefault(uid, path)
            self.log_tensor(f"{path}::{args[1]}", args[0])
        if self.verbose:
            print(f"{'[BW]' if self.mod_tracker.is_bw else '[FW]'} `{path}`: {func}")
        return out


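# Key naming produced by `AutoProbeD.__torch_dispatch__` (see also `test_toy_model`
# below): each linear layer stores "<path>::in", "<path>::w" and "<path>::out" in
# the forward pass and "<path>::in.g", "<path>::w.g" and "<path>::out.g" in the
# backward pass; attention ops additionally store "<path>::attn_logits" and
# "<path>::attn_entropy".

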
def _find_all_submodules_compiled(out: List[nn.Module], module: nn.Module) -> None:
    if module._compiled_call_impl is not None:
        out.append(module)
    for c in module.children():
        _find_all_submodules_compiled(out, module=c)


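# `TorchCompileDisabler` below temporarily clears each submodule's
# `_compiled_call_impl` and enters `torch.compiler.disable()` so the model runs
# eagerly while the probe is active; ops inside an already-compiled graph may
# bypass `__torch_dispatch__`, which is why torch-compile is disabled
# automatically (see the header comment).

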
class TorchCompileDisabler:
    def __init__(self, module: nn.Module) -> None:
        self.module = module
        self.submodules_compiled: List[nn.Module] = []
        self.compiled_call_impl: List[Any] = []
        self.disable_compile = torch.compiler.disable()
        torch._dynamo.config.raise_on_ctx_manager_usage = False  # type: ignore

    def __enter__(self) -> None:
        # Remove all `_compiled_call_impl` attributes to effectively
        # "undo" compilation
        self.submodules_compiled.clear()
        _find_all_submodules_compiled(self.submodules_compiled, self.module)
        self.compiled_call_impl = [
            m._compiled_call_impl for m in self.submodules_compiled
        ]
        for m in self.submodules_compiled:
            m._compiled_call_impl = None
        self.disable_compile.__enter__()  # type: ignore

    def __exit__(self, *args) -> None:
        self.disable_compile.__exit__(*args)  # type: ignore
        for m, c_impl in zip(self.submodules_compiled, self.compiled_call_impl):
            m._compiled_call_impl = c_impl
        self.compiled_call_impl = []


Probe = AutoProbeD


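# Reading the dump (illustrative sketch): `_flush_and_clear` appends one JSON object
# per probed step to `write_file`, so the output can be parsed as JSON lines:
#
#   with open("probe.json") as fd:
#       for line in fd:
#           record = json.loads(line)
#           print(record["meta"], sorted(record["data"].keys()))

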
# EXAMPLE USAGE
d = 512
seqlen = 4
bs = 2


class Attention1(nn.Module):
    def forward(self, x):
        attn_bias = fmha.attn_bias.LowerTriangularFromBottomRightMask()
        return fmha.memory_efficient_attention(x, x, x, attn_bias=attn_bias).reshape(
            [x.shape[0], seqlen, -1]
        )


class Attention2(nn.Module):
    def forward(self, x):
        attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            [seqlen] * bs
        ).make_causal()
        xr = x.reshape([1, 2 * seqlen, x.shape[2], x.shape[3]])
        return fmha.memory_efficient_attention(xr, xr, xr, attn_bias=attn_bias).reshape(
            [x.shape[0], seqlen, -1]
        )


class AttentionSDPA(nn.Module):
    def __init__(self):
        super().__init__()
        self.wo = nn.Linear(d, d)

    def forward(self, x):
        x = x.transpose(1, 2)
        return self.wo(
            F.scaled_dot_product_attention(x, x, x)
            .transpose(1, 2)
            .reshape([x.shape[0], seqlen, -1])
        )


class AttentionSDPAFlash(AttentionSDPA):
    def forward(self, x):
        x = x.transpose(1, 2)
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            return self.wo(
                F.scaled_dot_product_attention(x, x, x)
                .transpose(1, 2)
                .reshape([x.shape[0], seqlen, -1])
            )


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(d, 16)
        self.trunk = nn.Sequential(
            nn.Linear(d, d),
            nn.Linear(d, d),
        )
        self.q_proj = nn.Linear(d, d, bias=False)
        self.trunk.compile()
        self.attn1 = Attention1()
        self.attn2 = Attention2()
        self.attnSDPA = AttentionSDPA()
        self.attnSDPAflash = AttentionSDPAFlash()

    def forward(self, x):
        B, nHeads, D = x.shape[0], d // 64, 64
        x = self.q_proj(x).reshape([B, seqlen, nHeads, D])
        x = self.attn1(x) + self.attn2(x) + self.attnSDPA(x) + self.attnSDPAflash(x)
        x = log_stats(x, "attns_out")
        return self.head(self.trunk(x))


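# The toy `Model` above mixes xFormers memory-efficient attention (bottom-right
# causal and block-diagonal causal biases), plain SDPA, SDPA forced onto the flash
# backend, and a compiled `trunk`; `test_toy_model` below additionally checkpoints
# the head and compiles the whole model, so the probe's compile disabling and its
# "._checkpoint_wrapped_module" path cleanup are both exercised.

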
def test_masking() -> None:
    q_seqlen = [1, 1, 14, 12]
    kv_seqlen = [2, 2, 14, 18]
    attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
        q_seqlen, kv_seqlen
    ).make_causal_from_bottomright()
    logits = torch.randn(
        [1, 1, sum(q_seqlen), sum(kv_seqlen)], dtype=torch.float32, device="cuda"
    )
    bias = attn_bias.materialize(logits.shape, dtype=logits.dtype, device=logits.device)
    logits_masked = logits.clone()
    _mask_attn_logits(
        logits_masked,
        list(range(logits.shape[2])),
        causal=True,
        cu_seqlens_q=attn_bias.q_seqinfo.seqstart,
        cu_seqlens_k=attn_bias.k_seqinfo.seqstart,
    )
    assert (logits + bias == logits_masked).all().item()


def test_toy_model() -> None:
    # Test masking
    kw = dict(device="cuda", dtype=torch.float16)
    x = torch.randn([bs, seqlen, d], **kw)
    m = Model()
    m.head = checkpoint_wrapper(
        m.head, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )
    m.to(**kw)
    m.compile()
    optim = torch.optim.SGD(m.parameters(), lr=0.0)
    probe = AutoProbeD(m, "./probe.json")

    for i in range(4):
        with contextlib.ExitStack() as stack:
            print(f"########### STEP {i}")
            if i % 4 == 1:
                stack.enter_context(probe)
                probe.metadata = {"it": i}
            y = m(x)
            g = torch.randn_like(y)
            y.backward(g)
            if i % 4 == 1:
                assert probe.enabled
                # Make sure we registered all linears
                print(list(probe.store.keys()))
                for key in [
                    "Model::attns_out",
                    "Model::attns_out.g",
                    "Model.attn1::attn_logits",
                    "Model.attn2::attn_logits",
                    "Model.attnSDPA::attn_logits",
                    "Model.attnSDPAflash::attn_logits",
                    "Model.head::w",
                    "Model.head::w.g",
                    "Model.head::in",
                    "Model.head::in.g",
                    "Model.head::out",
                    "Model.head::out.g",
                    "Model.trunk.0::in",
                    "Model.trunk.1::in",
                ]:
                    assert key in probe.store, f"Missing key: '{key}'"
                # .. and that the values are correct
                for key, tensor in [
                    ("Model.head::w", m.head.weight),
                    ("Model.head::w.g", m.head.weight.grad),
                    ("Model.q_proj::in", x),
                    ("Model.q_proj::w.g", m.q_proj.weight.grad),
                    ("Model.head::out", y),
                    ("Model.head::out.g", g),
                ]:
                    assert key in probe.store, f"Missing key: '{key}'"
                    assert torch.allclose(
                        probe.store[key]["abs.mean"], tensor.float().abs().mean()
                    ), f"'{key}' mismatches"
                # Check we don't have `nans`
                for key, value in probe.store.items():
                    if "abs.mean" in value:
                        assert math.isfinite(
                            value["abs.mean"].item()
                        ), f"Inf/Nan for {key}"
            optim.step()
            optim.zero_grad()