mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 16:37:46 +00:00
153 lines
4.8 KiB
Python
153 lines
4.8 KiB
Python
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
|
|
||
|
import re
|
||
|
import warnings
|
||
|
from typing import Callable
|
||
|
|
||
|
import torch
|
||
|
|
||
|
# avoid division by zero when calculating scale
|
||
|
EPS = 1e-12
|
||
|
|
||
|
|
||
|
def scale(t, amax_t, dtype_t):
    """Quantize ``t`` to ``dtype_t`` using a scale derived from ``amax_t``.

    The scale is ``amax_t`` cast to float32, floored at ``EPS`` so the
    division below never divides by zero, and divided by the largest finite
    value of ``dtype_t``. ``t`` is divided by that scale, saturated to the
    finite range of ``dtype_t`` and then cast.

    Args:
        t: Tensor to quantize.
        amax_t: Absolute-maximum tensor the scale is derived from; it has to
            broadcast against ``t``.
        dtype_t: Target floating-point dtype (callers in this file pass
            ``torch.float8_e4m3fn``).

    Returns:
        A ``(t_fp8, scale_t)`` tuple: the quantized tensor and the float32
        scale it was divided by.
    """
    limits = torch.finfo(dtype_t)
    scale_t = amax_t.float().clamp(min=EPS) / limits.max
    # Saturate before casting so out-of-range values land on the dtype's
    # finite extremes.
    saturated = (t / scale_t).clamp(min=limits.min, max=limits.max)
    return saturated.to(dtype_t), scale_t
|
||
|
|
||
|
|
||
|
def matmul(
    first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias
):
    """Multiply ``first`` by the transpose of ``second_t`` via ``torch._scaled_mm``.

    Both operands are quantized with ``scale`` using the amax tensors supplied
    by the caller. The quantized second operand and its scale are transposed
    before being handed to ``torch._scaled_mm``.

    Args:
        first: Left operand.
        amax_first: Absolute-maximum tensor used to scale ``first``.
        dtype_first: Quantized dtype for ``first``.
        second_t: Right operand, stored transposed.
        amax_second_t: Absolute-maximum tensor used to scale ``second_t``.
        dtype_second_t: Quantized dtype for ``second_t``.
        bias: Optional bias forwarded to ``torch._scaled_mm``; may be ``None``.

    Returns:
        The ``torch._scaled_mm`` result, requested as ``torch.bfloat16``.
    """
    lhs_fp8, lhs_scale = scale(first, amax_first, dtype_first)
    rhs_t_fp8, rhs_t_scale = scale(second_t, amax_second_t, dtype_second_t)
    return torch._scaled_mm(
        lhs_fp8,
        rhs_t_fp8.t(),
        scale_a=lhs_scale,
        scale_b=rhs_t_scale.t(),
        bias=bias,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )
|
||
|
|
||
|
|
||
|
@torch._dynamo.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
    """Autograd function computing a linear layer through fp8 matmuls.

    The forward product and the input gradient go through ``matmul`` with
    ``torch.float8_e4m3fn`` operands. The weight gradient is a plain
    ``grad_out.t() @ a`` and the bias gradient is a sum over dim 0.
    """

    @staticmethod
    def forward(ctx, a, b_t, bias):
        """Return the fp8 product of ``a`` with the transpose of ``b_t``, plus ``bias``.

        Args:
            ctx: Autograd context; receives the requires-grad flags and the
                tensors needed by ``backward``.
            a: Input tensor. ``backward`` applies ``.t()`` to its gradient, so
                it is expected to be 2-D (``Fp8Linear`` flattens leading
                dimensions before calling).
            b_t: Weight tensor, multiplied in transposed form.
            bias: Optional bias tensor, or ``None``.
        """
        fp8 = torch.float8_e4m3fn
        # One absolute maximum per row: reduce the last dim, keep it for
        # broadcasting inside ``scale``.
        row_amax_a = a.abs().amax(dim=-1, keepdim=True)
        row_amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
        out = matmul(a, row_amax_a, fp8, b_t, row_amax_b_t, fp8, bias)

        # Remember which inputs need gradients so backward can skip the rest.
        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias is not None and bias.requires_grad

        # Only the global maximum of the weight amax is kept for backward.
        ctx.save_for_backward(a, b_t, row_amax_b_t.max())

        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(a, b_t, bias)``; ``None`` where not required."""
        a, b_t, amax_b = ctx.saved_tensors
        fp8 = torch.float8_e4m3fn
        grad_a = grad_b = grad_bias = None

        if ctx.a_requires_grad:
            b = b_t.t().contiguous()
            grad_a = matmul(
                grad_out,
                grad_out.abs().amax(dim=-1, keepdim=True),
                fp8,
                b,
                # Reuse the single saved weight amax for every row of ``b``.
                amax_b.repeat(b.shape[0], 1),
                fp8,
                None,
            )

        if ctx.b_requires_grad:
            grad_b = grad_out.t() @ a

        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)

        return grad_a, grad_b, grad_bias
|
||
|
|
||
|
|
||
|
class Fp8Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose forward pass is computed by ``Fp8LinearFn``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``input``, preserving its leading dimensions.

        All dimensions but the last are collapsed into one before calling
        ``Fp8LinearFn`` and restored on the result.
        """
        leading_shape = input.shape[:-1]
        flat_input = input.flatten(end_dim=-2)
        flat_out = Fp8LinearFn.apply(flat_input, self.weight, self.bias)
        return flat_out.unflatten(0, leading_shape)
|
||
|
|
||
|
|
||
|
def named_replace(
    fn: Callable[[torch.nn.Module, str], torch.nn.Module],
    module: torch.nn.Module,
    name="",
) -> torch.nn.Module:
    """Apply ``fn`` to every module in the tree rooted at ``module``, children first.

    Each child is processed recursively and re-attached to its parent under
    the same attribute name before ``fn`` is applied to the parent itself.

    Args:
        fn: Called as ``fn(module, qualified_name)``; its return value takes
            the place of ``module``.
        module: Root of the (sub)tree to process; its children are replaced
            in place.
        name: Dotted name of ``module`` relative to the outermost root; empty
            for the root itself.

    Returns:
        Whatever ``fn`` returns for ``module``.
    """
    prefix = f"{name}." if name else ""
    # Snapshot the children first: the setattr below mutates the parent while
    # we are walking over it.
    for child_name, child in list(module.named_children()):
        replacement = named_replace(fn, child, prefix + child_name)
        setattr(module, child_name, replacement)
    return fn(module, name)
|
||
|
|
||
|
|
||
|
def convert_linears_to_fp8(
    root_module: torch.nn.Module, recipe: str, filter: str
) -> torch.nn.Module:
    """Replace matching ``torch.nn.Linear`` modules under ``root_module`` with ``Fp8Linear``.

    A module is converted when it is a ``torch.nn.Linear`` and ``filter``
    (a regular expression, matched with ``re.search``) matches its dotted
    name. The replacement reuses the original ``weight`` and ``bias``
    parameters. As side effects, Inductor's Triton multi-kernel option is
    enabled and the Dynamo code caches and CUDA-graph trees are reset.

    Args:
        root_module: Root of the module tree; its children are swapped in place.
        recipe: Scaling recipe; only ``"rowwise"`` is supported.
        filter: Regular expression selecting which module names to convert.

    Returns:
        The converted root (a new ``Fp8Linear`` if the root itself matched,
        otherwise ``root_module``).

    Raises:
        RuntimeError: If ``recipe`` is not a known recipe.
        AssertionError: If a matching module is a subclass of
            ``torch.nn.Linear`` rather than ``torch.nn.Linear`` itself.
    """
    if recipe not in ["rowwise"]:
        raise RuntimeError(f"Unknown float8 recipe {recipe!r}")

    if recipe == "rowwise" and torch.__version__ < "2.5":
        # We need https://github.com/pytorch/pytorch/pull/134781.
        warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")

    # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
    # reduction kernel and a "persistent" reduction kernel. Since fp8 has some
    # multi-pass steps (e.g., first get amax, then scale), persistent kernels
    # should perform better.
    torch._inductor.config.triton.multi_kernel = 1

    filter_re = re.compile(filter)

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        """Return an ``Fp8Linear`` sharing ``module``'s parameters, or ``module`` itself."""
        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
            return module
        # Raise explicitly instead of ``assert False``: asserts are stripped
        # under ``python -O``, which would turn these failures into an
        # UnboundLocalError on the return below.
        if type(module) is not torch.nn.Linear:
            raise AssertionError(str(type(module)))
        if recipe != "rowwise":
            raise AssertionError(recipe)

        new_module = Fp8Linear(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share the existing parameters rather than the freshly created ones.
        new_module.weight = module.weight
        new_module.bias = module.bias
        return new_module

    out = named_replace(replace, root_module)

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()

    return out
|