blt/bytelatent/float8.py
2024-12-12 15:32:30 -08:00

153 lines
4.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import re
import warnings
from typing import Callable
import torch
# avoid division by zero when calculating scale
EPS = 1e-12
def scale(t, amax_t, dtype_t):
min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max
scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v
t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t)
return t_fp8, scale_t
def matmul(
    first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias
):
    """Multiply ``first`` by the transpose of ``second_t`` via ``torch._scaled_mm``.

    Both operands are quantized with ``scale`` using the supplied amax values
    and dtypes, then passed to ``torch._scaled_mm`` together with their
    scales.

    Args:
        first: Left operand.
        amax_first: Absolute maxima used to derive the scale of ``first``.
        dtype_first: Dtype ``first`` is quantized to.
        second_t: Right operand, stored transposed; its ``.t()`` is what gets
            multiplied.
        amax_second_t: Absolute maxima used to derive the scale of
            ``second_t``.
        dtype_second_t: Dtype ``second_t`` is quantized to.
        bias: Optional bias forwarded to ``torch._scaled_mm`` (may be ``None``).

    Returns:
        The result of ``torch._scaled_mm``, requested as ``torch.bfloat16``.
    """
    lhs_fp8, lhs_scale = scale(first, amax_first, dtype_first)
    rhs_t_fp8, rhs_t_scale = scale(second_t, amax_second_t, dtype_second_t)
    # The right operand and its scale are both transposed so they line up
    # with the layout the kernel receives.
    return torch._scaled_mm(
        lhs_fp8,
        rhs_t_fp8.t(),
        scale_a=lhs_scale,
        scale_b=rhs_t_scale.t(),
        bias=bias,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )
@torch._dynamo.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
    """Autograd function for a linear layer whose matmuls run in float8.

    The forward pass and the input-gradient pass go through ``matmul`` with
    ``torch.float8_e4m3fn`` operands. The weight gradient is computed with a
    regular ``@`` matmul.
    """

    @staticmethod
    def forward(ctx, a, b_t, bias):
        """Compute the linear output for ``a`` with weight ``b_t`` and ``bias``.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            a: Input activations (``Fp8Linear`` passes them flattened to 2-D).
            b_t: Weight, used as the transposed right operand of ``matmul``.
            bias: Optional bias tensor, or ``None``.

        Returns:
            The output of ``matmul``.
        """
        fp8 = torch.float8_e4m3fn
        # One absolute maximum per row (last dim reduced, kept for broadcast).
        row_amax_a = a.abs().amax(dim=-1, keepdim=True)
        row_amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
        result = matmul(a, row_amax_a, fp8, b_t, row_amax_b_t, fp8, bias)

        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias is not None and bias.requires_grad
        # Only the global maximum of the weight amax is kept; backward
        # broadcasts it to every row of the transposed weight.
        ctx.save_for_backward(a, b_t, row_amax_b_t.max())
        return result

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(a, b_t, bias)``; ``None`` where not required."""
        a, b_t, weight_amax = ctx.saved_tensors
        grad_a = grad_b = grad_bias = None

        if ctx.a_requires_grad:
            b = b_t.t().contiguous()
            grad_a = matmul(
                grad_out,
                grad_out.abs().amax(dim=-1, keepdim=True),
                torch.float8_e4m3fn,
                b,
                # Same scalar amax repeated once per row of ``b``.
                weight_amax.repeat(b.shape[0], 1),
                torch.float8_e4m3fn,
                None,
            )

        if ctx.b_requires_grad:
            # Weight gradient is not routed through the float8 matmul.
            grad_b = grad_out.t() @ a

        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)

        return grad_a, grad_b, grad_bias
class Fp8Linear(torch.nn.Linear):
    """``torch.nn.Linear`` variant whose forward pass goes through ``Fp8LinearFn``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``input`` and restore its leading dimensions.

        All dimensions except the last are collapsed into one before calling
        ``Fp8LinearFn`` and expanded again on the result.
        """
        leading_shape = input.shape[:-1]
        rows = input.flatten(end_dim=-2)
        flat_out = Fp8LinearFn.apply(rows, self.weight, self.bias)
        return flat_out.unflatten(0, leading_shape)
def named_replace(
    fn: Callable[[torch.nn.Module, str], torch.nn.Module],
    module: torch.nn.Module,
    name="",
) -> torch.nn.Module:
    """Rebuild a module tree bottom-up by passing every module through ``fn``.

    Children are processed first and re-attached to their parent under the
    same attribute name; ``fn`` is then applied to the parent itself.

    Args:
        fn: Called as ``fn(module, qualified_name)``; returns the module that
            should take that module's place (possibly the same object).
        module: Root of the subtree to process.
        name: Dotted qualified name of ``module``; empty for the root.

    Returns:
        Whatever ``fn`` returns for ``module`` after its children were handled.
    """
    prefix = f"{name}." if name else ""
    # Snapshot the children up front because they are re-assigned in the loop.
    children = list(module.named_children())
    for key, child in children:
        setattr(module, key, named_replace(fn, child, prefix + key))
    return fn(module, name)
def convert_linears_to_fp8(
    root_module: torch.nn.Module, recipe: str, filter: str
) -> torch.nn.Module:
    """Swap matching ``torch.nn.Linear`` modules under ``root_module`` for ``Fp8Linear``.

    Every module that is exactly ``torch.nn.Linear`` and whose dotted
    qualified name matches ``filter`` is replaced by an ``Fp8Linear`` that
    reuses the original ``weight`` and ``bias`` parameter objects.

    Side effects: sets ``torch._inductor.config.triton.multi_kernel`` to 1,
    resets the Dynamo code caches and resets the cudagraph trees.

    Args:
        root_module: Root of the module tree to convert; children are
            re-assigned in place on their parents.
        recipe: Scaling recipe; only ``"rowwise"`` is accepted.
        filter: Regular expression applied with ``re.search`` to each
            module's qualified name (the root's name is the empty string).

    Returns:
        The converted tree: ``root_module`` itself, or its replacement when
        the root is a matching ``torch.nn.Linear``.

    Raises:
        RuntimeError: If ``recipe`` is not a known recipe.
        AssertionError: If a matching module is a subclass of
            ``torch.nn.Linear`` rather than ``torch.nn.Linear`` itself.
    """
    if recipe not in ["rowwise"]:
        raise RuntimeError(f"Unknown float8 recipe {recipe!r}")

    if recipe == "rowwise" and torch.__version__ < "2.5":
        # We need https://github.com/pytorch/pytorch/pull/134781.
        warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")

    # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
    # reduction kernel and a "persistent" reduction kernel. Since fp8 has some
    # multi-pass steps (e.g., first get amax, then scale), persistent kernels
    # should perform better.
    torch._inductor.config.triton.multi_kernel = 1

    filter_re = re.compile(filter)

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
            return module

        # Raise explicitly instead of `assert False`: asserts are stripped
        # under `python -O`, which would let execution fall through to an
        # unbound `new_module`.
        if type(module) is not torch.nn.Linear:
            raise AssertionError(str(type(module)))
        if recipe != "rowwise":
            # Unreachable while the recipe check above stays in sync.
            raise AssertionError(recipe)

        new_module = Fp8Linear(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share the existing parameters rather than the freshly initialised ones.
        new_module.weight = module.weight
        new_module.bias = module.bias
        return new_module

    out = named_replace(replace, root_module)

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()

    return out