diff --git a/ktransformers/ktransformers_ext/triton/fp8gemm.py b/ktransformers/ktransformers_ext/triton/fp8gemm.py
new file mode 100644
index 0000000..4da4cfe
--- /dev/null
+++ b/ktransformers/ktransformers_ext/triton/fp8gemm.py
@@ -0,0 +1,191 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+from triton import Config
+
+
+@triton.jit
+def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+    """
+    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
+
+    Args:
+        x_ptr (triton.Pointer): Pointer to the input tensor.
+        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
+        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
+        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
+
+    Returns:
+        None
+    """
+    pid = tl.program_id(axis=0)
+    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    x = tl.load(x_ptr + offs).to(tl.float32)
+    s = tl.max(tl.abs(x)) / 448.
+    y = x / s
+    y = y.to(y_ptr.dtype.element_ty)
+    tl.store(y_ptr + offs, y)
+    tl.store(s_ptr + pid, s)
+
+
+def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantizes the input tensor `x` using block-wise quantization.
+
+    Args:
+        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
+        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - The quantized tensor with dtype `torch.float8_e4m3fn`.
+            - A tensor of scaling factors with dtype `torch.float32`.
+    """
+    assert x.is_contiguous(), 'Input tensor must be contiguous'
+    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
+    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
+    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
+    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
+    return y, s
+
+
+@triton.jit
+def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    """
+    Dequantizes weights using the provided scaling factors and stores the result.
+
+    Args:
+        x_ptr (tl.pointer): Pointer to the quantized weights.
+        s_ptr (tl.pointer): Pointer to the scaling factors.
+        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
+        M (int): Number of rows in the weight matrix.
+        N (int): Number of columns in the weight matrix.
+        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
+
+    Returns:
+        None
+    """
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    n = tl.cdiv(N, BLOCK_SIZE)
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs = offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+    s = tl.load(s_ptr + pid_m * n + pid_n)
+    y = x * s
+    tl.store(y_ptr + offs, y, mask=mask)
+
+
+def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+    """
+    Dequantizes the given weight tensor using the provided scale tensor.
+
+    Args:
+        x (torch.Tensor): The quantized weight tensor of shape (M, N).
+        s (torch.Tensor): The per-block scale tensor of shape (M // block_size, N // block_size).
+        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+
+    Returns:
+        torch.Tensor: The dequantized weight tensor of the same shape as `x`.
+
+    Raises:
+        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
+    """
+    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
+    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
+    M, N = x.size()
+    y = torch.empty_like(x, dtype=torch.get_default_dtype())
+    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+    return y
+
+
+fp8_gemm_configs = [
+    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
+    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
+]
+
+@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
+@triton.jit
+def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
+                    a_s_ptr, b_s_ptr,
+                    M, N: tl.constexpr, K: tl.constexpr,
+                    BLOCK_SIZE_M: tl.constexpr,
+                    BLOCK_SIZE_N: tl.constexpr,
+                    BLOCK_SIZE_K: tl.constexpr):
+    """
+    Performs a matrix multiplication operation on FP8 matrices with scaling factors.
+
+    Args:
+        a_ptr (tl.tensor): Pointer to the first input matrix A.
+        b_ptr (tl.tensor): Pointer to the second input matrix B.
+        c_ptr (tl.tensor): Pointer to the output matrix C.
+        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
+        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
+        M (int): Number of rows in matrix A and C.
+        N (tl.constexpr): Number of columns in matrix B and C.
+        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
+        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
+        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
+        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
+
+    Returns:
+        None
+    """
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    k = tl.cdiv(K, BLOCK_SIZE_K)
+    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
+    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
+    a_s_ptrs = a_s_ptr + offs_m * k
+    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for i in range(k):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
+        a_s = tl.load(a_s_ptrs)
+        b_s = tl.load(b_s_ptrs)
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+        a_s_ptrs += 1
+        b_s_ptrs += 1
+    c = accumulator.to(c_ptr.dtype.element_ty)
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    tl.store(c_ptrs, c, mask=mask)
+
+
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+    """
+    Perform a matrix multiplication using FP8 precision.
+
+    Args:
+        a (torch.Tensor): The first input matrix, must be contiguous.
+        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
+        b (torch.Tensor): The second input matrix, must be contiguous.
+        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
+
+    Returns:
+        torch.Tensor: The result of the matrix multiplication.
+    """
+    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
+    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
+    K = a.size(-1)
+    M = a.numel() // K
+    N = b.size(0)
+    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
+    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
+    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
+    return c
\ No newline at end of file
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 08a2cca..5aff964 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -25,6 +25,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
 )
 from ktransformers.operators.base_operator import BaseInjectedModule
 from transformers.configuration_utils import PretrainedConfig
+from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
 from abc import ABC, abstractmethod
 import sys, os
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
@@ -164,7 +165,65 @@ class KLinearTorch(KLinearBase):
         if self.has_bias:
             self.bias = None
 
-
+class KLinearFP8(KLinearBase):
+    marlin_q_w: torch.Tensor
+    marlin_s: torch.Tensor
+    g_idx: torch.Tensor
+    sort_indices: torch.Tensor
+    has_bias: bool
+    weight: torch.Tensor
+    scale_w: torch.Tensor
+    bias: torch.Tensor
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        device: str = "cuda",
+        block_size: int = 128,
+        **kwargs,
+    ):
+        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        self.has_bias = False
+        self.dtype = torch.get_default_dtype()
+        self.block_size = block_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.to(self.device)
+        orig_shape = list(x.shape)
+        orig_dtype = x.dtype
+        x = x.reshape(-1, orig_shape[-1])
+        x_quantized, scale_x = act_quant(x, self.block_size)
+        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
+        if self.has_bias:
+            y += self.bias
+        return y.to(orig_dtype).reshape(orig_shape)
+
+    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if device is None: device = self.device
+        if w is None:
+            w = self.load_weight(device=device)
+        if isinstance(w, nn.Parameter):
+            self.weight = w.to(device)
+            self.has_bias = False
+        elif isinstance(w, tuple):
+            self.weight = w[0].to(device)
+            self.bias = w[1].to(device)
+            self.has_bias = True
+        else:
+            raise ValueError("Invalid weight type")
+        self.weight = self.weight.to(device)
+        if self.has_bias:
+            self.bias = self.bias.to(device)
+
+    def unload(self):
+        if self.weight is not None:
+            self.weight = None
+        if self.has_bias:
+            self.bias = None
+
+
 class KLinearMarlin(KLinearBase):
     marlin_q_w: torch.Tensor
     marlin_s: torch.Tensor
diff --git a/ktransformers/tests/triton_fp8gemm_test.py b/ktransformers/tests/triton_fp8gemm_test.py
new file mode 100644
index 0000000..bb3801c
--- /dev/null
+++ b/ktransformers/tests/triton_fp8gemm_test.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn.functional as F
+from typing import Optional
+import pytest
+from typing import Tuple, Optional, Literal
+
+# add the repo root to sys.path
+import os
+import sys
+sys.path.insert(0, "/home/azure/ktransformers")
+print(sys.path)
+from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
+from safetensors import safe_open
+
+world_size = 1
+rank = 0
+block_size = 128
+gemm_impl: Literal["bf16", "fp8"] = "bf16"
+# `fp8_gemm`, `act_quant` and `weight_dequant` are imported above from the Triton FP8 GEMM module
+
+def test_fp8_gemm_vs_torch_matmul():
+    # Test case 1: Create random matrices of size (M, K) and (N, K)
+    M, K, N = 64, 128, 256  # Matrix dimensions
+    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+    weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')
+
+    # Apply act_quant to both matrices
+    x_quantized, scale_x = act_quant(x, block_size)
+    weight_quantized, scale_w = act_quant(weight, block_size)
+
+    # make contiguous
+    x_quantized = x_quantized.contiguous()
+    weight_quantized = weight_quantized.contiguous()
+    scale_x = scale_x.contiguous()
+    scale_w = scale_w.contiguous()
+
+    # Perform fp8_gemm using the quantized tensors
+    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w)
+
+    # Perform torch.matmul using the original floating point tensors
+    result_torch_matmul = torch.matmul(x, weight.T)
+    print(f'result_torch_matmul: {result_torch_matmul.shape}')
+    print(f'result_fp8_gemm: {result_fp8_gemm.shape}')
+
+    print(f"result_fp8_gemm:\n {result_fp8_gemm}")
+    print(f"result_torch_matmul:\n {result_torch_matmul}")
+
+def test_fp8_gemm_vs_torch_matmul_load():
+    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
+    with safe_open(file_path, framework="pt", device=0) as f:
+        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
+        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
+
+    # dequantize the FP8 weight using its block-wise scales
+    weight_dequantized = weight_dequant(weight, scale)
+    print(f"weight_dequantized: {weight_dequantized.shape}")
+    N, K = weight_dequantized.shape
+    M = 64
+    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+    x_quantized, scale_x = act_quant(x, block_size)
+
+    # Test case 1: quantized x matmul with the still-quantized FP8 weight
+    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+    print(f"result_fp8_gemm:\n {result_fp8_gemm}")
+
+    # Perform torch.matmul using the original floating point tensors
+    result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
+    print(f"result_torch_matmul:\n {result_torch_matmul}")
+
+if __name__ == "__main__":
+    test_fp8_gemm_vs_torch_matmul()
+    test_fp8_gemm_vs_torch_matmul_load()
+    
\ No newline at end of file
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index eaa1a7d..26afd39 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -127,6 +127,7 @@ GGML_BLOCK_SIZES = {
     "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
     "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
     "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
+    "FP8": 1,
 }
 
 GGML_ELEMENTS_PER_BLOCK = {
@@ -142,6 +143,7 @@ GGML_ELEMENTS_PER_BLOCK = {
     "Q5_K": 256,
     "Q6_K": 256,
     "IQ4_XS": 256,
+    "FP8": 1,
 }
 
 DATA_TYPES = {
@@ -158,6 +160,7 @@ DATA_TYPES = {
     "uint64": 10,
     "int64": 11,
     "float64": 12,
+    "FP8": 13,
 }
 
 class GGUFLoader:
@@ -393,6 +396,9 @@ def read_value(f, data_type):
         elem_type, count = struct.unpack("