# kvcache-ai-ktransformers/csrc/custom_marlin/test_cuda_graph.py
import csv
import torch
import torch.nn as nn
import vLLMMarlin
from utils.marlin_utils import (
    MarlinWorkspace,
    marlin_quantize,
    GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_MIN_THREAD_K,
    GPTQ_MARLIN_MAX_PARALLEL,
)
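
# vLLMMarlin is the custom Marlin GPTQ CUDA extension (presumably built from the kernels
# in this csrc/custom_marlin directory); it exposes the gptq_marlin_gemm kernel used below.
# utils.marlin_utils supplies the workspace, the quantization helper (unused in this
# benchmark), and the thread/parallelism constants used for padding.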
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

setup_seed(20241223)
torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
global_dtype = torch.bfloat16
global_device = torch.device("cuda", 0)
global_num_cases: int = 50
torch.cuda.set_device(0)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
max_batch_size = 512
max_tp = 8                    # largest tensor-parallel degree to sweep
L2_size = 73728 * 1024        # assumed GPU L2 cache size in bytes (72 MiB)
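
# Rough estimate of device memory still available for benchmark tensors,
# keeping a 512 MiB safety margin on top of what is already allocated.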
def get_usable_mem():
    properties = torch.cuda.get_device_properties(global_device)
    #print(f"Total memory: {properties.total_memory / (1024 ** 3):.2f} GB")
    allocated_memory = torch.cuda.memory_allocated(global_device)
    #print(f"Currently allocated memory: {allocated_memory / (1024 ** 2):.2f} MB")
    reserved_memory = torch.cuda.memory_reserved(global_device)
    #print(f"Currently reserved memory: {reserved_memory / (1024 ** 2):.2f} MB")
    return properties.total_memory - 512 * 1024 ** 2 - allocated_memory  # - reserved_memory
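
# Geometric range: yields start, start*step, start*step**2, ... while <= stop,
# e.g. exp_range(1, 64) -> 1, 2, 4, 8, 16, 32, 64.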
def exp_range(start, stop, step=2):
    now = start
    while now <= stop:
        yield now
        now *= step
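
# Time one call of `func` by capturing `iters` calls into a single CUDA graph and
# replaying the graph. A 10-replay measurement is subtracted from an (epochs + 10)-replay
# measurement, leaving roughly the time for `epochs` replays; the return value is the
# average time per call in milliseconds.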
def timing(func, iters, epochs=100):
    # warmup (also allocates any lazily created buffers before graph capture)
    for idx in range(iters):
        func(idx)
    torch.cuda.synchronize()
    cuda_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cuda_graph):
        for idx in range(iters):
            func(idx)
    # warm up the captured graph before timing
    for _ in range(2000):
        cuda_graph.replay()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    stream = torch.cuda.Stream()
    torch.cuda.synchronize()
    #with torch.cuda.stream(stream):
    start_event.record()
    for _ in range(10):
        cuda_graph.replay()
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms0 = start_event.elapsed_time(end_event)
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    #with torch.cuda.stream(stream):
    start_event.record()
    for _ in range(epochs + 10):
        cuda_graph.replay()
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event) - elapsed_time_ms0
    #print(elapsed_time_ms0, elapsed_time_ms)
    return elapsed_time_ms / iters / epochs
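
# nn.Linear subclass whose forward pass calls the vLLMMarlin GPTQ GEMM kernel.
# For this benchmark the packed weight and scales are filled with random data rather
# than produced by marlin_quantize (the real quantization path is left commented out
# in __init__), since only kernel timing matters here, not numerical correctness.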
class LinearMarlin(nn.Linear):
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
    g_idx: torch.Tensor
    sort_indices: torch.Tensor
    has_bias: bool

    def __init__(
        self,
        in_features,
        out_features,
        bias=False,
        device: str = "cuda",
        num_bits: int = 4,      # 4-bit/8-bit is supported
        group_size: int = 64,   # -1, 32, 64, 128
        act_order: bool = False,
        is_k_full=True,
        sms=-1,                 # SM count passed to the kernel
        **kwargs,
    ):
        self.padding = False
        assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
        if in_features % GPTQ_MARLIN_MIN_THREAD_K != 0 or out_features % GPTQ_MARLIN_MIN_THREAD_K != 0:
            #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
            self.padding = True
            self.orin_in_features = in_features
            self.orin_out_features = out_features
            in_features = (in_features + GPTQ_MARLIN_MIN_THREAD_K - 1) // GPTQ_MARLIN_MIN_THREAD_K * GPTQ_MARLIN_MIN_THREAD_K
            out_features = (out_features + GPTQ_MARLIN_MIN_THREAD_N - 1) // GPTQ_MARLIN_MIN_THREAD_N * GPTQ_MARLIN_MIN_THREAD_N
            #print(f"After padding: in_features={in_features}, out_features={out_features}")
        super().__init__(in_features, out_features, bias, device)
        self.has_bias = bias
        self.device = device
        self.num_bits = num_bits
        self.group_size = group_size
        self.act_order = act_order
        # TODO: optimize every shape GEMM
        blocks_k, blocks_n = in_features // 128, out_features // 128
        self.sms = sms
        self.is_k_full = is_k_full
        self.weight.requires_grad = False
        self.weight.t_()
        # Pack Marlin linear: for this benchmark, random packed weights and scales are used
        # instead of running the real quantization below.
        #w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        #    self.weight, self.num_bits, self.group_size, self.act_order
        #)
        marlin_q_w = torch.randint(int(-1e9), int(1e9), (in_features // 16, out_features * 2), device=device, dtype=torch.int)
        marlin_s = torch.randn((in_features // 64, out_features), device=device)
        self.workspace = MarlinWorkspace(
            self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device
        )
        self.marlin_q_w = marlin_q_w
        self.marlin_s = marlin_s
        self.g_idx = torch.empty((0), dtype=torch.int32, device=self.device)
        self.sort_indices = torch.empty((0), dtype=torch.int32, device=self.device)
        self.k = self.weight.shape[0]
        self.n = self.weight.shape[1]
        self.weight = None
        """
        print(in_features, out_features)
        print(marlin_q_w.shape)
        print(marlin_q_w.dtype)
        print(marlin_s.shape)
        print(marlin_s.dtype)
        print(self.workspace.scratch.shape)
        print(self.workspace.scratch.dtype)
        print(self.g_idx.shape)
        print(self.g_idx.dtype)
        print(self.sort_indices.shape)
        print(self.sort_indices.dtype)
        #print(w_ref.shape)
        #print(w_ref.dtype)
        """
        #w_ref = None

    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
        # Only BF16 and FP16 inputs are supported
        x = x.to(self.device)
        orig_shape = list(x.shape)
        orig_dtype = x.dtype
        x = x.reshape(-1, x.shape[-1])
        if self.padding:
            # note: the padded tail of the input is left uninitialized, which is fine
            # for timing but not for numerical checks
            padding_input = torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)
            padding_input[:, :self.orin_in_features] = x
            x = padding_input
        marlin_s = self.marlin_s.to(x.dtype)
        #print(self.sms * ((orig_shape[0]+63)//64))
        sms = self.sms
        x = vLLMMarlin.gptq_marlin_gemm(
            x,
            self.marlin_q_w,
            marlin_s,
            self.g_idx,
            self.sort_indices,
            self.workspace.scratch,
            self.num_bits,
            bsz_tensor,
            x.shape[0],
            self.n,
            x.shape[-1],
            sms,
            self.is_k_full,
        )
        # TODO: don't pad the bias
        if self.has_bias:
            x = x + self.bias
        if self.padding:
            x = x[:, :self.orin_out_features]
            orig_shape[-1] = self.orin_out_features
        else:
            orig_shape[-1] = self.out_features
        return x.reshape(orig_shape).to(orig_dtype)
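
# Sweep batch size (powers of two up to 64) and tensor-parallel degree (up to max_tp).
# For each configuration, enough independent LinearMarlin copies are created that the
# combined working set exceeds the L2 cache, then per-call latency, effective bandwidth
# (GB/s) and TFLOPS are derived from the CUDA-graph timing above.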
def benchLinearMarlin(input_dim, output_dim):  #, out_file
    print("benchmarking MLP Marlin")
    print("-----------------------------------------------------------")
    headers = ["batch_size", "tp", "used_time", "bandwidth GB/s", "TFLOPS", "cases", "padding", "sms"]
    print(" | ".join(headers) + "\n")
    rows = []
    for batch_size in exp_range(1, 64):
        for tp in exp_range(1, max_tp):
            torch.cuda.empty_cache()
            if output_dim % tp != 0:
                continue
            cur_output_dim = output_dim // tp
            modules = []
            inputs = []
            # 4-bit packed weight (0.5 B/element) plus one bf16 scale per 64-element group
            data_size = int(0.53125 * input_dim * cur_output_dim)
            input_size = int(2 * batch_size * input_dim)
            output_size = int(2 * batch_size * cur_output_dim)
            usable_mem = get_usable_mem() - 2 * input_dim * cur_output_dim
            # use enough independent cases that the working set exceeds the L2 cache
            min_cases = max(global_num_cases, (2 * L2_size) // (data_size + input_size))
            cases = int(min(min_cases, (usable_mem * 0.8) // (data_size + input_size)))
            #print(usable_mem, data_size, input_size, cases)
            bsz_tensor = torch.tensor([batch_size], device=global_device, dtype=torch.int32)
            if cases == 0:
                row = [f"{batch_size}", f"{tp}", "OOM", "OOM", "OOM", "0", "False", "-"]
                rows.append(row)
                break
            for _ in range(cases):
                modules.append(LinearMarlin(input_dim, cur_output_dim, sms=56, non_equal_division=False).to(device=global_device).eval())
                inputs.append(torch.randn(batch_size, 1, input_dim, device=global_device))
            def forward(case_id):
                modules[case_id](inputs[case_id], bsz_tensor)
            used_time = timing(forward, iters=cases)
            bandwidth = (data_size + input_size + output_size) / used_time / 1e6  # GB/s (used_time is in ms)
            flops = 2 * batch_size * input_dim * cur_output_dim
            tflops = flops / used_time / 1e9  # TFLOPS (used_time is in ms)
            cur_sms = modules[0].sms
            row = [f"{batch_size}", f"{tp}", f"{used_time}", f"{bandwidth}", f"{tflops}", f"{cases}", modules[0].padding, cur_sms]
            rows.append(row)
            print(f"{batch_size}", f"{tp}", f"{used_time}", f"{bandwidth}", f"{tflops}", f"{cases}", modules[0].padding, cur_sms)
    """
    with open(out_file, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(headers)
        for row in rows:
            csvwriter.writerow(row)
    """
    """
    markdown_table = " | ".join(headers) + "\n"
    markdown_table += " | ".join(["---"] * len(headers)) + "\n"
    for row in rows:
        markdown_table += " | ".join(row) + "\n"
    print(markdown_table)
    """
    #print("finish write file", out_file)
    #print("-------------------------------------------------------------")
if __name__ == "__main__":
    benchLinearMarlin(5120, 3584)
    exit(0)

    max_batch = 1
    cur_batch = 1
    marlin_linear = LinearMarlin(5120, 3584)
    input_tensor = torch.randn(max_batch, 1, 5120, device="cuda", dtype=torch.bfloat16)
    bsz_tensor = torch.tensor([max_batch], device="cuda", dtype=torch.int32)
    out_truth = marlin_linear(input_tensor, bsz_tensor)
    print(out_truth)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out_buf = marlin_linear(input_tensor, bsz_tensor)
    for i in range(10000):
        g.replay()
    #torch.testing.assert_close(out_buf, out_truth, rtol=1e-3, atol=1e-3)
    marlin_linear = LinearMarlin(5120, 3584)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out_buf = marlin_linear(input_tensor, bsz_tensor)
    new_input = torch.randn(cur_batch, 1, 5120, device="cuda", dtype=torch.bfloat16)
    bsz_tensor.copy_(torch.tensor([cur_batch], device="cuda", dtype=torch.int32))
    new_out_truth = marlin_linear(new_input, bsz_tensor)
    input_tensor[:cur_batch].copy_(new_input)
    input_tensor[cur_batch:] = 0
    g.replay()
    torch.cuda.synchronize()

    def printMinMax(tensor):
        abs_tensor = torch.abs(tensor)
        min_val = torch.min(abs_tensor)
        max_val = torch.max(abs_tensor)
        min_indices = (abs_tensor == min_val).nonzero(as_tuple=True)
        max_indices = (abs_tensor == max_val).nonzero(as_tuple=True)
        print(f"min: {min_val.item()}")
        print(f"min idx: {min_indices}")
        print(f"max: {max_val.item()}")
        print(f"max idx: {max_indices}")

    print(out_buf[:cur_batch].shape)
    print(new_out_truth.shape)
    printMinMax(out_buf[:cur_batch])
    printMinMax(new_out_truth)
    #torch.testing.assert_close(out_buf[:cur_batch, 0, :], new_out_truth[:cur_batch, 0, :], rtol=1e-3, atol=1e-3)
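
# Hypothetical usage, assuming the vLLMMarlin extension has already been built and is
# importable from this directory:
#   python test_cuda_graph.py
# This prints one "batch_size tp used_time bandwidth tflops cases padding sms" line per
# configuration in the sweep.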