Add fp8 linear kernel; add empty cache to fit in 16G VRAM. By "wkGCaSS - Zhihu: https://zhuanlan.zhihu.com/p/25491611225"

Azure 2025-02-22 13:05:08 +00:00
parent b4fb633991
commit 7b7c6a657d
5 changed files with 331 additions and 2 deletions

View file: ktransformers/ktransformers_ext/triton/fp8gemm.py

@@ -0,0 +1,191 @@
from typing import Tuple
import torch
import triton
import triton.language as tl
from triton import Config
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
Args:
x_ptr (triton.Pointer): Pointer to the input tensor.
y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
Returns:
None
"""
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.  # 448 is the max representable magnitude of float8_e4m3fn
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), 'Input tensor must be contiguous'
assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
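# Shape sketch (illustrative, assuming the default block_size=128): a (4, 256)
# bf16 input yields an fp8 tensor of the same shape plus a (4, 2) float32 scale
# tensor, and the input is recovered up to quantization error as y * s per block:
#
#   x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
#   y, s = act_quant(x)
#   x_hat = y.to(torch.float32) * s.repeat_interleave(128, dim=-1)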
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.
Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
s_ptr (tl.pointer): Pointer to the scaling factors.
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
n = tl.cdiv(N, BLOCK_SIZE)
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
Args:
x (torch.Tensor): The quantized weight tensor of shape (M, N).
s (torch.Tensor): The scale tensor of shape (M // block_size, N // block_size), one factor per (block_size, block_size) tile.
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
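# Shape sketch (assuming block_size=128): a (1024, 512) fp8 weight pairs with an
# (8, 4) float32 scale tensor, one factor per 128x128 tile:
#
#   w = weight_dequant(w_fp8, w_s)  # -> (1024, 512) in the default dtype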
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
Args:
a_ptr (tl.tensor): Pointer to the first input matrix A.
b_ptr (tl.tensor): Pointer to the second input matrix B.
c_ptr (tl.tensor): Pointer to the output matrix C.
a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
M (int): Number of rows in matrix A and C.
N (tl.constexpr): Number of columns in matrix B and C.
K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
# A is (M, K) row-major; B is (N, K) row-major, so B is read as (K, N) tiles.
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
# Per-row, per-K-block scales for A; per (BLOCK_SIZE_K x BLOCK_SIZE_K) block scales for B.
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1
c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
"""
Perform a matrix multiplication using FP8 precision.
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix in (N, K) row-major layout (i.e. the transposed weight), must be contiguous.
b_s (torch.Tensor): The per-block scaling factors for `b`, must be contiguous.
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c
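# Hedged smoke test of the call pattern (assumes a CUDA device with fp8 e4m3
# support; not part of the kernel API itself):
if __name__ == "__main__":
    x = torch.randn(64, 4096, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")
    x_q, x_s = act_quant(x)  # fp8 activations + per-128-column scales
    # NOTE: real checkpoints ship per-128x128-block `weight_scale_inv` scales;
    # act_quant is only a stand-in for the weight here, so expect loose agreement.
    w_q, w_s = act_quant(w)
    y = fp8_gemm(x_q, x_s, w_q, w_s)
    print(y.shape)  # torch.Size([64, 1024])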

View file

@@ -25,6 +25,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from abc import ABC, abstractmethod
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
@@ -164,6 +165,64 @@ class KLinearTorch(KLinearBase):
if self.has_bias:
self.bias = None
class KLinearFP8(KLinearBase):
has_bias: bool
weight: torch.Tensor
scale_w: torch.Tensor
bias: torch.Tensor
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
block_size: int = 128,
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.dtype = torch.get_default_dtype()
self.block_size = block_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
orig_shape = list(x.shape)
orig_dtype = x.dtype
x = x.reshape(-1, orig_shape[-1])
x_quantized, scale_x = act_quant(x, self.block_size)
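# NOTE: assumes the loader attached the per-block scale tensor to the fp8
# weight as a `.scale` attribute; the `scale_w` field declared above is unused here.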
y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
if self.bias is not None:
y += self.bias
return y.to(orig_dtype).reshape(orig_shape)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None:
w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
self.weight = w.to(device)
self.bias = None  # forward() checks self.bias, so it must exist even without a bias
self.has_bias = False
elif isinstance(w, tuple):
self.weight = w[0].to(device)
self.bias = w[1].to(device)
self.has_bias = True
else:
raise ValueError("Invalid weight type")
self.weight = self.weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
def unload(self):
if self.weight is not None:
self.weight = None
if self.has_bias:
self.bias = None
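# Usage sketch (hypothetical wiring; in practice the module is injected through
# the optimize-rule YAML rather than constructed by hand):
#
#   lin = KLinearFP8(key, gguf_loader, config, orig_module=orig_linear)
#   lin.load()                      # expects an fp8 weight carrying a `.scale` attribute
#   y = lin.forward(hidden_states)  # bf16 in, bf16 out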
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor

View file

@@ -0,0 +1,73 @@
import torch
import torch.nn.functional as F
from typing import Tuple, Optional, Literal
import pytest
# Make the local ktransformers checkout importable when running this file directly.
import os
import sys
sys.path.insert(0, "/home/azure/ktransformers")
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors import safe_open
world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
# Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined
def test_fp8_gemm_vs_torch_matmul():
# Test case 1: Create random matrices of size (M, K) and (K, N)
M, K, N = 64, 128, 256 # Matrix dimensions
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')
# Apply act_quant to both matrices
x_quantized, scale_x = act_quant(x, block_size)
weight_quantized, scale_w = act_quant(weight, block_size)
# make the quantized tensors and their scales contiguous
x_quantized = x_quantized.contiguous()
weight_quantized = weight_quantized.contiguous()
scale_x = scale_x.contiguous()
scale_w = scale_w.contiguous()
# Perform fp8_gemm using the quantized tensors
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w)
# Perform torch.matmul using the original floating point tensors
result_torch_matmul = torch.matmul(x, weight.T)
print(f'result_torch_matmul: {result_torch_matmul.shape}')
print(f'result_fp8_gemm: {result_fp8_gemm.shape}')
print(f"result_fp8_gemm:\n {result_fp8_gemm}")
print(f"result_torch_matmul:\n {result_torch_matmul}")
def test_fp8_gemm_vs_torch_matmul_load():
file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
with safe_open(file_path, framework="pt", device=0) as f:
weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
# weight_dequant
weight_dequantized = weight_dequant(weight, scale)
print(f"weight_dequantized: {weight_dequantized.shape}")
N, K = weight_dequantized.shape
M = 64
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
x_quantized, scale_x = act_quant(x, block_size)
# quantized x matmul with the still-quantized weight and its block scales
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
print(f"result_fp8_gemm:\n {result_fp8_gemm}")
# Perform torch.matmul using the original floating point tensors
result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
print(f"result_torch_matmul:\n {result_torch_matmul}")
if __name__ == "__main__":
test_fp8_gemm_vs_torch_matmul()
test_fp8_gemm_vs_torch_matmul_load()

View file

@@ -127,6 +127,7 @@ GGML_BLOCK_SIZES = {
"Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
"Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
"IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
"FP8": 1,
}
GGML_ELEMENTS_PER_BLOCK = {
@@ -142,6 +143,7 @@ GGML_ELEMENTS_PER_BLOCK = {
"Q5_K": 256,
"Q6_K": 256,
"IQ4_XS": 256,
"FP8": 1,
}
DATA_TYPES = {
@@ -158,6 +160,7 @@ DATA_TYPES = {
"uint64": 10,
"int64": 11,
"float64": 12,
"FP8": 13,
}
class GGUFLoader:
@@ -393,6 +396,9 @@ def read_value(f, data_type):
elem_type, count = struct.unpack("<IQ", f.read(4 + 8))
return [read_value(f, elem_type) for _ in range(count)]
elif data_type == DATA_TYPES["FP8"]:
return struct.unpack("<B", f.read(1))[0]
else:
raise NotImplementedError(f"Data type {data_type} not implemented")

View file

@@ -70,7 +70,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
# device = "cpu" if "embd" in translated_key else "cuda"
torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - Zhihu: https://zhuanlan.zhihu.com/p/25491611225"
weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
set_param(module, name, weights)
del weights
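# Hedged sketch (not part of the original change): one way to observe what the
# empty_cache call above actually frees between tensor loads:
#
#   before = torch.cuda.memory_reserved()
#   torch.cuda.empty_cache()
#   print(f"released {(before - torch.cuda.memory_reserved()) / 2**20:.0f} MiB")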