mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
116 lines
No EOL
4.2 KiB
Python
116 lines
No EOL
4.2 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from typing import Optional
|
|
import pytest
|
|
from typing import Tuple, Optional, Literal
|
|
import time
|
|
# Make the local ktransformers checkout importable (its path is added to sys.path below).
|
|
import os
|
|
import sys
|
|
# Put the ktransformers checkout first on sys.path so its Triton FP8 kernels can be imported
# when this file is run directly. The location defaults to the path this script was written
# against and can be overridden with the KTRANSFORMERS_ROOT environment variable.
_KTRANSFORMERS_ROOT = os.environ.get("KTRANSFORMERS_ROOT", "/home/azure/ktransformers")
sys.path.insert(0, _KTRANSFORMERS_ROOT)

from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors import safe_open

# Single-process settings; only `block_size` is referenced by the tests below.
world_size = 1
rank = 0
# Quantization block width passed to act_quant.
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
|
|
# `fp8_gemm`, `act_quant` and `weight_dequant` are imported above from ktransformers.ktransformers_ext.triton.fp8gemm.
|
|
|
|
def _quantize_weight_per_tile(weight: torch.Tensor, tile: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D weight to FP8 (e4m3) with one float32 scale per `tile` x `tile` block.

    Returns `(q, s)` where `q` has the shape of `weight` and `s` has shape
    (N // tile, K // tile); `q.float()` times the broadcast `s` approximates `weight`.

    NOTE(review): this is the tile layout a DeepSeek-V3-style fp8_gemm reads weight scales in
    (the same layout the checkpoint `weight_scale_inv` tensors are passed with in the load test)
    -- confirm against fp8gemm.py if the kernel changes.
    """
    n, k = weight.shape
    if n % tile or k % tile:
        raise ValueError(f"weight shape {tuple(weight.shape)} is not a multiple of {tile}")
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    tiles = weight.float().reshape(n // tile, tile, k // tile, tile)
    # Clamp so an all-zero tile yields a tiny scale instead of 0 (which would give NaNs).
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp_min(1e-12)
    scale = amax / fp8_max
    q = (tiles / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(n, k).contiguous(), scale.reshape(n // tile, k // tile).contiguous()


def _dequant_row_blocks(q: torch.Tensor, s: torch.Tensor, block: int) -> torch.Tensor:
    """Dequantize with one scale per row for every run of `block` columns (float32 result).

    Assumes `s` has shape (*q.shape[:-1], q.shape[-1] // block), the layout act_quant is
    expected to return -- TODO confirm against fp8gemm.py.
    """
    *lead, k = q.shape
    return (q.float().reshape(*lead, k // block, block) * s.unsqueeze(-1)).reshape(*lead, k)


def _dequant_tiles(q: torch.Tensor, s: torch.Tensor, tile: int) -> torch.Tensor:
    """Dequantize a 2-D tensor quantized with one scale per `tile` x `tile` block (float32)."""
    n, k = q.shape
    return (q.float().reshape(n // tile, tile, k // tile, tile) * s[:, None, :, None]).reshape(n, k)


def _relative_error(actual: torch.Tensor, expected: torch.Tensor) -> float:
    """Return ||actual - expected|| / ||expected|| (Frobenius norms, computed in float32)."""
    actual, expected = actual.float(), expected.float()
    return ((actual - expected).norm() / expected.norm()).item()


def test_fp8_gemm_vs_torch_matmul():
    """Check fp8_gemm against torch.matmul on random bf16 operands.

    The activation is quantized with act_quant and the weight per 128x128 tile. The FP8
    result is then compared with:
      * a float32 matmul of the *dequantized* operands -- isolates the GEMM kernel itself,
        so only accumulation differences should remain;
      * the bf16 matmul of the original operands -- bounds the end-to-end FP8 error.
    """
    torch.manual_seed(0)
    M, K, N = 64, 128, 256  # K and N must be multiples of block_size
    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')

    x_quantized, scale_x = act_quant(x, block_size)
    # Do not reuse act_quant for the weight: it produces per-row scales of shape
    # (N, K // block_size), while fp8_gemm is expected to read weight scales per 128x128 tile
    # (see _quantize_weight_per_tile), which would apply the wrong scale to most rows.
    weight_quantized, scale_w = _quantize_weight_per_tile(weight, block_size)

    # Make every operand contiguous before handing it to the Triton kernel.
    x_quantized = x_quantized.contiguous()
    weight_quantized = weight_quantized.contiguous()
    scale_x = scale_x.contiguous()
    scale_w = scale_w.contiguous()

    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w)

    # Reference 1: full-precision operands.
    result_torch_matmul = torch.matmul(x, weight.T)
    # Reference 2: exactly the values the kernel sees, multiplied in float32.
    result_dequant_matmul = torch.matmul(
        _dequant_row_blocks(x_quantized, scale_x, block_size),
        _dequant_tiles(weight_quantized, scale_w, block_size).T,
    )

    print(f'result_torch_matmul: {result_torch_matmul.shape}')
    print(f'result_fp8_gemm: {result_fp8_gemm.shape}')

    kernel_err = _relative_error(result_fp8_gemm, result_dequant_matmul)
    e2e_err = _relative_error(result_fp8_gemm, result_torch_matmul)
    print(f"relative error vs dequantized reference: {kernel_err:.4e}")
    print(f"relative error vs bf16 reference: {e2e_err:.4e}")

    assert result_fp8_gemm.shape == result_torch_matmul.shape
    assert kernel_err < 1e-2, f"fp8_gemm disagrees with dequantized matmul: {kernel_err:.4e}"
    # e4m3 keeps 3 mantissa bits, so quantizing both operands is estimated to cost a few
    # percent; the bound leaves headroom while still catching wrong-scale bugs.
    assert e2e_err < 0.1, f"fp8_gemm too far from bf16 matmul: {e2e_err:.4e}"
|
|
|
|
def test_fp8_gemm_vs_torch_matmul_load():
    """Check fp8_gemm against a bf16 reference using a real DeepSeek-V3 FP8 weight.

    Loads layer 0's `mlp.down_proj` weight and its `weight_scale_inv` tensor from a local
    checkpoint shard. A random activation is multiplied with fp8_gemm, with the weight left
    quantized, and the result is compared against torch.matmul using the weight dequantized
    by weight_dequant. Only the fp8 path quantizes the activation, so the two results should
    differ by activation quantization error alone.
    """
    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
    with safe_open(file_path, framework="pt", device=0) as f:
        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")

    weight_dequantized = weight_dequant(weight, scale)
    print(f"weight_dequantized: {weight_dequantized.shape}")
    N, K = weight_dequantized.shape
    M = 64
    torch.manual_seed(0)
    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
    x_quantized, scale_x = act_quant(x, block_size)

    # Quantized activation multiplied with the still-quantized checkpoint weight.
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    print(f"result_fp8_gemm:\n {result_fp8_gemm}")
    print(f"dtype {result_fp8_gemm.dtype}")

    # Reference: original bf16 activation times the dequantized weight.
    result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
    print(f"result_torch_matmul:\n {result_torch_matmul}")

    assert result_fp8_gemm.shape == result_torch_matmul.shape == (2, M, N)
    expected = result_torch_matmul.float()
    rel_err = ((result_fp8_gemm.float() - expected).norm() / expected.norm()).item()
    print(f"relative error: {rel_err:.4e}")
    # e4m3 keeps 3 mantissa bits, so activation quantization is estimated to cost a few
    # percent; the bound leaves headroom while still catching scale/layout bugs.
    assert rel_err < 0.1, f"fp8_gemm too far from dequantized reference: {rel_err:.4e}"
|
|
|
|
def test_fp8_gemm_tplops():
    """Benchmark act_quant + fp8_gemm throughput on a real DeepSeek-V3 FP8 weight.

    Each timed iteration quantizes the activation and then runs fp8_gemm. The reported
    TFLOPS counts only the GEMM's 2 * rows * N * K FLOPs, so the act_quant time included in
    the measurement makes it a slight underestimate of pure GEMM throughput.
    """
    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
    with safe_open(file_path, framework="pt", device=0) as f:
        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")

    weight_dequantized = weight_dequant(weight, scale)
    print(f"weight_dequantized: {weight_dequantized.shape}")
    N, K = weight_dequantized.shape
    M = 6400
    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')

    num_warmup = 2
    num_iters = 10
    # x has a leading batch of 2, so every GEMM multiplies a (2 * M, K) activation;
    # counting only M would under-report throughput by 2x.
    rows = x.numel() // K
    flops_per_gemm = 2 * rows * N * K
    total_flops = num_iters * flops_per_gemm

    # Warm-up: triggers Triton compilation/autotuning outside the timed region.
    for _ in range(num_warmup):
        x_quantized, scale_x = act_quant(x, block_size)
        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)

    # Kernel launches are asynchronous: drain the warm-up work *before* starting the clock,
    # and synchronize again before stopping it.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(num_iters):
        x_quantized, scale_x = act_quant(x, block_size)
        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    torch.cuda.synchronize()
    t1 = time.perf_counter()

    total_time = t1 - t0
    tflops = total_flops / total_time / 1e12
    print(f"total_time: {total_time}")
    print(f"tflops: {tflops}")
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run each check in order (pytest collects them on its own).
    for _check in (
        test_fp8_gemm_vs_torch_matmul,
        test_fp8_gemm_vs_torch_matmul_load,
        test_fp8_gemm_tplops,
    ):
        _check()
|
|
|