[feature] support q2_k & q3_k dequantize on gpu

BITcyman 2024-08-12 12:53:12 +00:00
parent 650c368c18
commit 7c4cb520bd
5 changed files with 161 additions and 12 deletions

View file

@@ -4,7 +4,7 @@
  * @Date : 2024-07-25 13:38:30
  * @Version : 1.0.0
  * @LastEditors : kkk1nak0
- * @LastEditTime : 2024-08-09 01:45:02
+ * @LastEditTime : 2024-08-12 03:05:04
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
@@ -27,6 +27,10 @@ PYBIND11_MODULE(KTransformersOps, m) {
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
         py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
         py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),

View file

@@ -13,6 +13,7 @@ int test(){
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
 
 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
@@ -23,6 +24,10 @@ PYBIND11_MODULE(cudaops, m) {
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("test", &test, "Function to test.");
 }

View file

@@ -4,7 +4,7 @@
  * @Date : 2024-07-25 13:38:30
  * @Version : 1.0.0
  * @LastEditors : kkk1nak0
- * @LastEditTime : 2024-08-09 07:57:06
+ * @LastEditTime : 2024-08-12 04:18:04
  * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
 * Copyright (c) 2023-2024 The ggml authors
 * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
@@ -36,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
     }
 }
 
+__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    // Grid-stride loop: each thread dequantizes whole 256-weight super-blocks.
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        // fp16 super-block scale and min sit at byte offsets 80 and 82 of each block.
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
+        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 82)));
+        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                // Each scale byte packs a 4-bit scale (low nibble) and a 4-bit min (high nibble).
+                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                uint8_t sc = *scales;
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+                scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                sc = *scales;
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+                shift += 2;
+            }
+            q += 32;
+        }
+    }
+}
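
The hard-coded offsets above (scales at 0, quants at 16, d at 80, dmin at 82) follow ggml's block_q2_K layout, which the kernel assumes rather than deriving from blk_size. A sketch of that layout for reference (byte offsets, QK_K = 256):

# Q2_K super-block: 84 bytes -> 256 dequantized weights
Q2_K_LAYOUT = {
    "scales": (0, 16),   # per-16-weight bytes: 4-bit scale (low nibble), 4-bit min (high nibble)
    "qs":     (16, 64),  # 2-bit quants, four weights per byte
    "d":      (80, 2),   # fp16 super-block scale, multiplies each 4-bit scale
    "dmin":   (82, 2),   # fp16 super-block min, multiplies each 4-bit min
}
# weight = d * scale4 * q2 - dmin * min4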
+__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        uint32_t aux[4];
+        const int8_t * scales = (const int8_t*)aux;
+        // fp16 super-block scale at byte offset 108; low 2-bit quants at 32; high-bit mask at 0.
+        const float d_all = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 108)));
+        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
+        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
+        uint8_t m = 1;
+        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
+        // Load the twelve packed scale bytes into three little-endian words.
+        for (int i = 0; i < 3; i++) {
+            aux[i] = 0;
+            for (int j = 0; j < 4; j++) {
+                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
+            }
+        }
+        // Shuffle them into sixteen 6-bit scales, read back through `scales` as signed bytes.
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+        int is = 0;
+        float dl;
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    // A clear hmask bit means the quant's high bit is 0, i.e. subtract 4.
+                    *output_blk++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
+                }
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
+                }
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+    }
+}
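
Likewise, the q3_k offsets (hmask at 0, quants at 32, packed scales at 96, d at 108) mirror ggml's block_q3_K. The kmask shuffle reassembles the twelve packed scale bytes into sixteen 6-bit scales, which are centered by subtracting 32; the high bit of each 3-bit quant comes from hmask, and a clear bit subtracts 4 from the 2-bit value. Layout sketch (byte offsets):

# Q3_K super-block: 110 bytes -> 256 dequantized weights
Q3_K_LAYOUT = {
    "hmask":  (0, 32),   # third (high) bit of each quant, eight weights per byte
    "qs":     (32, 64),  # low 2 bits of each quant, four weights per byte
    "scales": (96, 12),  # sixteen 6-bit scales packed into twelve bytes
    "d":      (108, 2),  # fp16 super-block scale
}
# weight = d * (scale6 - 32) * (low2 - (0 if high_bit else 4))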
 __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
@@ -176,6 +267,24 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device) {
     return output;
 }
 
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+    data_gpu.copy_(data, false);
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+    // Launch kernel
+    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+    cudaDeviceSynchronize();
+    return output;
+}
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
@@ -196,8 +305,7 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
     return output;
 }
 
-torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
     int num_blocks = data.numel() / blk_size;
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
@@ -209,7 +317,25 @@ torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
     auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
     // Launch kernel
-    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+    dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
     cudaDeviceSynchronize();
     return output;
 }
 
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+    data_gpu.copy_(data, false);
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+    // Launch kernel
+    dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+    cudaDeviceSynchronize();
+    return output;
+}
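
All three wrappers follow the same pattern: copy the raw bytes to the device, allocate a [num_blocks, 256] float32 output, and launch a fixed <<< 512, 256 >>> grid whose threads walk blocks in a grid-stride loop. Note that num_blocks comes from integer division without checking the remainder; a Python-side guard could fail early on truncated buffers. A sketch of such a guard (dequantize_checked is a hypothetical helper, not part of this commit):

import torch

def dequantize_checked(op, raw: torch.Tensor, blk_size: int, device: str = "cuda"):
    # Fail early on buffers that are not a whole number of blocks,
    # since the C++ wrappers silently truncate via integer division.
    if raw.numel() % blk_size != 0:
        raise ValueError(f"{raw.numel()} bytes is not a multiple of blk_size={blk_size}")
    return op(raw, blk_size, torch.device(device))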

View file

@@ -4,7 +4,7 @@
  * @Date : 2024-07-22 09:27:55
  * @Version : 1.0.0
  * @LastEditors : kkk1nak0
- * @LastEditTime : 2024-08-09 01:44:21
+ * @LastEditTime : 2024-08-12 03:48:46
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #pragma once
@@ -16,4 +16,6 @@
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);

View file

@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
 Date : 2024-07-26 08:48:54
 Version : 1.0.0
 LastEditors : kkk1nak0
-LastEditTime : 2024-08-09 08:03:44
+LastEditTime : 2024-08-12 07:21:55
 Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
 Copyright (c) 2023-2024 The ggml authors
 Copyright (c) 2024 Thomas Germer
@@ -390,8 +390,14 @@ def dequantize_q2_k(data):
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
 
-def dequantize_q2_k_gpu(data):
-    raise NotImplementedError()
+def dequantize_q2_k_gpu(data, device: str = "cuda"):
+    block_size = GGML_BLOCK_SIZES["Q2_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable;
+    # the best way to fix this is to pass the pointer to KTransformersOps instead of a Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q2_k(data, block_size, device)
 
 def dequantize_q3_k(data):
     # C implementation
@@ -435,8 +441,14 @@ def dequantize_q3_k(data):
         (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
     ], axis=1)
 
-def dequantize_q3_k_gpu(data):
-    raise NotImplementedError()
+def dequantize_q3_k_gpu(data, device: str = "cuda"):
+    block_size = GGML_BLOCK_SIZES["Q3_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable;
+    # the best way to fix this is to pass the pointer to KTransformersOps instead of a Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q3_k(data, block_size, device)
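
A short usage sketch for the new GPU paths (the buffer below is an illustrative placeholder; in practice the bytes come from a GGUF tensor, and GGML_BLOCK_SIZES supplies the per-block byte counts, 84 for Q2_K and 110 for Q3_K):

import numpy as np

raw = np.zeros(110 * 8, dtype=np.uint8)        # eight placeholder Q3_K blocks
deq = dequantize_q3_k_gpu(raw, device="cuda")  # float32 tensor of shape [8, 256]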
 
 def dequantize_q4_k(data):
     # C implementation