mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 12:40:02 +00:00
234 lines
No EOL
7.2 KiB
Python
234 lines
No EOL
7.2 KiB
Python
"""This file is used for /tests and /benchmarks"""
|
|
import random
|
|
|
|
import numpy
|
|
import torch
|
|
|
|
from .format24 import (
|
|
mask_creator, sparse_semi_structured_from_dense_cutlass)
|
|
from .marlin_24_perms import (
|
|
marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)
|
|
from .marlin_perms import (
|
|
marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
|
|
from .quant_utils import (
|
|
get_pack_factor, quantize_weights, sort_weights, dequantize_weights)
|
|
|
|
|
|
|
|
# Compute capability (major, minor) of the current CUDA device.
# Querying it unconditionally raises at import time on machines without a
# usable CUDA device, which makes the whole module unimportable. Fall back to
# (0, 0) in that case so that is_marlin_supported() simply reports False.
__cuda_arch = (torch.cuda.get_device_capability()
               if torch.cuda.is_available() else (0, 0))

# Default tile edge used by marlin_permute_weights().
MARLIN_TILE = 16

# GPTQ-Marlin parameters. They are not referenced elsewhere in this file;
# presumably they are consumed by the tests/benchmarks that import this
# module (e.g. as MarlinWorkspace arguments) -- TODO confirm against callers.
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

# Configurations the GPTQ-Marlin path is expected to accept
# (group size -1 is normalized to "one group spanning all of K" by the
# quantize helpers in this file).
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
|
|
|
|
def is_marlin_supported():
    """Report whether Marlin is considered supported on the detected device.

    Returns:
        True when the major compute capability stored in the module-level
        ``__cuda_arch`` is 8 or higher, False otherwise.
    """
    major_version = __cuda_arch[0]
    return major_version >= 8
|
|
|
|
|
|
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
    """Rearrange a (size_k, size_n) weight matrix into tiled, permuted order.

    The matrix is first regrouped into ``tile`` x ``tile`` tiles so that each
    output row holds one row of tiles, then ``perm`` is applied independently
    to every consecutive chunk of ``perm.numel()`` elements.

    Args:
        q_w: Tensor of shape ``(size_k, size_n)``.
        size_k: Number of rows of ``q_w``; must be a multiple of ``tile``.
        size_n: Number of columns of ``q_w``; must be a multiple of ``tile``.
        perm: 1-D index tensor used to reorder each chunk of
            ``perm.numel()`` elements.
        tile: Tile edge length (defaults to ``MARLIN_TILE``).

    Returns:
        Tensor of shape ``(size_k // tile, size_n * tile)``.

    Raises:
        AssertionError: If the shape does not match or a dimension is not a
            multiple of ``tile``.
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # The message previously mislabelled this value as "size_k".
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Split both dimensions into tiles, bring the two tile indices together,
    # and flatten so each row contains one full row of tiles.
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    # Apply the permutation chunk by chunk, then restore the tiled shape.
    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
|
|
|
|
|
|
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Permute quantized weights and pack them into 32-bit words.

    Args:
        q_w: Quantized weight tensor of shape ``(size_k, size_n)``.
        size_k: Number of rows of ``q_w``.
        size_n: Number of columns of ``q_w``.
        num_bits: Bit width of each quantized value.
        perm: Permutation forwarded to ``marlin_permute_weights``.

    Returns:
        int32 tensor on the device of the permuted weights. Each word holds
        ``get_pack_factor(num_bits)`` values; the value at offset ``i`` inside
        a group is stored starting at bit ``num_bits * i``.
    """
    permuted = marlin_permute_weights(q_w, size_k, size_n, perm)
    target_device = permuted.device
    values_per_word = get_pack_factor(num_bits)

    # The bit packing is done with NumPy on the host using unsigned words.
    unpacked = permuted.cpu().numpy().astype(numpy.uint32)
    packed_shape = (unpacked.shape[0], unpacked.shape[1] // values_per_word)
    packed = numpy.zeros(packed_shape, dtype=numpy.uint32)
    for slot in range(values_per_word):
        shift = num_bits * slot
        packed |= unpacked[:, slot::values_per_word] << shift

    # Reinterpret as int32 for torch and move back to the original device.
    return torch.from_numpy(packed.astype(numpy.int32)).to(target_device)
|
|
|
|
|
|
def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
                          scale_perm_single):
    """Reorder quantization scales with the appropriate permutation.

    Args:
        s: Tensor of scales.
        size_k: K dimension of the weight the scales belong to.
        size_n: N dimension; the result has ``size_n`` columns.
        group_size: Quantization group size; ``-1`` or any value
            ``>= size_k`` selects ``scale_perm_single``.
        scale_perm: Permutation used when there are several groups along K
            (``group_size < size_k`` and ``group_size != -1``).
        scale_perm_single: Permutation used otherwise.

    Returns:
        Contiguous tensor of shape ``(-1, size_n)`` holding the permuted
        scales.
    """
    has_multiple_groups = group_size != -1 and group_size < size_k
    chosen_perm = scale_perm if has_multiple_groups else scale_perm_single

    # Permute inside consecutive chunks as long as the chosen permutation.
    permuted = s.reshape((-1, len(chosen_perm)))[:, chosen_perm]

    return permuted.reshape((-1, size_n)).contiguous()
|
|
|
|
|
|
def marlin_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
    act_order: bool,
):
    """Quantize ``w`` and convert the result to the Marlin layout.

    Args:
        w: Weight tensor of shape ``(size_k, size_n)``.
        num_bits: Bit width; also the key used to look up ``marlin_perm``,
            ``marlin_scale_perm`` and ``marlin_scale_perm_single``.
        group_size: Quantization group size along K; ``-1`` means a single
            group spanning all of K.
        act_order: Forwarded to ``quantize_weights``; when true the quantized
            weights and group indices are additionally passed through
            ``sort_weights``.

    Returns:
        List ``[w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]``
        with every tensor moved to ``w.device``. ``sort_indices`` is an empty
        int tensor when ``act_order`` is false.

    Raises:
        AssertionError: If ``group_size`` exceeds ``size_k``.
    """
    size_k, size_n = w.shape

    # -1 is shorthand for "one group covering the whole K dimension".
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (act_order handling happens inside quantize_weights).
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(
        w, num_bits, group_size, act_order)

    if act_order:
        # Sort the weights and g_idx so that group ids are increasing.
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
    else:
        sort_indices = torch.empty(0, dtype=torch.int, device=w.device)

    # Convert weights and scales to the Marlin layout.
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
                                marlin_perm[num_bits])
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                     marlin_scale_perm[num_bits],
                                     marlin_scale_perm_single[num_bits])

    # Hand everything back on the device the input lives on.
    outputs = (w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm)
    return [tensor.to(w.device) for tensor in outputs]
|
|
|
|
|
|
def inject_24(w, size_k, size_n):
    """Apply the mask produced by ``mask_creator`` to ``w``.

    Args:
        w: Weight tensor of shape ``(size_k, size_n)``.
        size_k: Expected number of rows of ``w``.
        size_n: Expected number of columns of ``w``.

    Returns:
        Tuple ``(masked_w, mask)``: ``w`` multiplied element-wise by the
        boolean mask, and the mask itself, both contiguous and on
        ``w.device``.

    Raises:
        AssertionError: If ``w`` does not have shape ``(size_k, size_n)``.
    """
    assert w.shape == (size_k, size_n)

    # The mask is computed on the transpose and transposed back, i.e. the
    # grouping done by mask_creator runs along the K dimension of ``w``
    # (presumably 2:4 groups -- see mask_creator).
    # Keep the mask on the same device as ``w``: the previous hard-coded
    # ``.cuda()`` moved it to the current CUDA device, which made ``mask * w``
    # fail for CPU tensors or tensors on a non-current GPU.
    mask = mask_creator(w.t()).t().to(w.device).bool()

    return (mask * w).contiguous(), mask.contiguous()
|
|
|
|
|
|
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Print how many sampled 4-element blocks violate 2:4 sparsity.

    ``w`` is transposed first; rows of the transposed matrix are then sampled
    (with replacement) and each sampled row is scanned in consecutive,
    non-overlapping blocks of four elements. A block with more than two
    non-zero entries is reported and counted as a violation.

    Args:
        w: 2-D tensor to inspect.
        num_rows_to_sample: Number of rows (of the transposed matrix) to
            sample.
        _verbose: When true, also print the sampled row indices.

    Returns:
        None. Results are written to stdout.
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # "+ 1" makes the exclusive stop include the last full block; without
        # it the final BLOCK_SIZE columns of every row were never checked.
        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
|
|
|
|
|
|
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    """Compress a quantized weight with the CUTLASS semi-structured helper.

    Args:
        q_24: Quantized weight tensor of shape ``(size_k, size_n)``.
        size_k: Expected number of rows of ``q_24``.
        size_n: Expected number of columns of ``q_24``.
        num_bits: Bit width of the quantized values; the zero point is taken
            as half of ``2 ** num_bits``.

    Returns:
        Tuple ``(q_24_comp, meta)``: the compressed weight with the zero point
        added back, and the metadata tensor returned by
        ``sparse_semi_structured_from_dense_cutlass`` after being resized
        in place.

    Raises:
        AssertionError: If ``q_24`` does not have shape ``(size_k, size_n)``.
    """
    assert q_24.shape == (size_k, size_n)

    # Subtract the zero point so the values are centered on 0 (zeros in the
    # centered tensor are what the compression step treats as pruned --
    # presumably; see sparse_semi_structured_from_dense_cutlass).
    zero_point = (1 << num_bits) // 2
    centered = (q_24 - zero_point).t().contiguous()

    # The helper is fed the transposed matrix; transpose the result back.
    compressed, meta = sparse_semi_structured_from_dense_cutlass(centered)
    q_24_comp = compressed.t().contiguous() + zero_point

    # Give meta its actual shape in place (no data is moved).
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta
|
|
|
|
|
|
def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Sparsify, quantize and compress ``w`` into the Marlin 2:4 layout.

    Args:
        w: Weight tensor of shape ``(size_k, size_n)``.
        num_bits: Bit width; also the key used to look up ``marlin_24_perm``,
            ``marlin_24_scale_perm`` and ``marlin_24_scale_perm_single``.
        group_size: Quantization group size along K; ``-1`` means a single
            group spanning all of K.

    Returns:
        List ``[w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]`` with every
        tensor moved to ``w.device``.

    Raises:
        AssertionError: If ``group_size`` exceeds ``size_k``.
    """
    size_k, size_n = w.shape

    # -1 is shorthand for "one group covering the whole K dimension".
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Mask the weights first; the mask itself is not needed afterwards.
    sparse_w, _mask_24 = inject_24(w, size_k, size_n)

    # Quantize without activation reordering; g_idx / rand_perm are unused.
    w_24_ref, q_w_24, s, _g_idx, _rand_perm = quantize_weights(
        sparse_w, num_bits, group_size, act_order=False)

    # Compress the quantized weight; the compressed tensor is handled as
    # having size_k // 2 rows from here on.
    q_w_24_comp, meta = compress_quantized_24_weight(
        q_w_24, size_k, size_n, num_bits)
    compressed_k = size_k // 2

    # Convert weights and scales to the Marlin 2:4 layout.
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, compressed_k, size_n,
                                        num_bits, marlin_24_perm[num_bits])
    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                        marlin_24_scale_perm[num_bits],
                                        marlin_24_scale_perm_single[num_bits])

    # Hand everything back on the device the input lives on.
    outputs = (w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s)
    return [tensor.to(w.device) for tensor in outputs]
|
|
|
|
|
|
def compute_max_diff(output, output_ref):
    """Return the mean absolute error relative to the reference magnitude.

    Despite the name, this is a ratio of means rather than a maximum:
    ``mean(|output - output_ref|) / mean(|output_ref|)``.

    Args:
        output: Tensor to evaluate.
        output_ref: Reference tensor.

    Returns:
        Scalar tensor holding the ratio.
    """
    mean_abs_error = torch.mean(torch.abs(output - output_ref))
    mean_abs_ref = torch.mean(torch.abs(output_ref))
    return mean_abs_error / mean_abs_ref
|
|
|
|
|
|
class MarlinWorkspace:
    """Holds a zero-initialized int scratch tensor sized for a Marlin call.

    The buffer is exposed as ``self.scratch`` and has
    ``(out_features // min_thread_n) * max_parallel`` elements.
    """

    def __init__(self, out_features, min_thread_n, max_parallel, device):
        """Allocate the scratch buffer.

        Args:
            out_features: Output feature count; must be a multiple of
                ``min_thread_n``.
            min_thread_n: Divisor used to derive the number of slots.
            max_parallel: Multiplier applied to the number of slots.
            device: Device on which the buffer is allocated.

        Raises:
            AssertionError: If ``out_features`` is not a multiple of
                ``min_thread_n``.
        """
        assert out_features % min_thread_n == 0, (
            f"out_features = {out_features} is undivisible by "
            f"min_thread_n = {min_thread_n}")

        slot_count = out_features // min_thread_n
        max_workspace_size = slot_count * max_parallel

        self.scratch = torch.zeros(max_workspace_size,
                                   dtype=torch.int,
                                   device=device)