Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 15:29:39 +00:00)
[feature] support q2_k & q3_k dequantize on gpu
commit 7c4cb520bd (parent 650c368c18)
5 changed files with 161 additions and 12 deletions
@@ -4,7 +4,7 @@
 * @Date : 2024-07-25 13:38:30
 * @Version : 1.0.0
 * @LastEditors : kkk1nak0
-* @LastEditTime : 2024-08-09 01:45:02
+* @LastEditTime : 2024-08-12 03:05:04
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
@@ -27,6 +27,10 @@ PYBIND11_MODULE(KTransformersOps, m) {
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
         py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
         py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
@@ -13,6 +13,7 @@ int test(){
 
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
 
 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
@@ -23,6 +24,10 @@ PYBIND11_MODULE(cudaops, m) {
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("test", &test, "Function to test.");
 
 }
@@ -4,7 +4,7 @@
 * @Date : 2024-07-25 13:38:30
 * @Version : 1.0.0
 * @LastEditors : kkk1nak0
-* @LastEditTime : 2024-08-09 07:57:06
+* @LastEditTime : 2024-08-12 04:18:04
 * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
 * Copyright (c) 2023-2024 The ggml authors
 * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
@@ -36,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
     }
 }
 
+__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
+        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 82)));
+
+        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
+
+        int is = 0;
+        float dl, ml;
+
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                uint8_t sc = *scales;
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                sc = *scales;
+
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+    }
+}
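The kernel reads each 84-byte Q2_K block as 16 scale bytes at offset 0 (low nibble scales d, high nibble scales the minimum), 64 quant bytes at offset 16 holding four 2-bit weights each, and two fp16 values d and dmin at offsets 80 and 82. A per-block NumPy reference with the same indexing, offered only as a cross-check sketch (the helper name is illustrative):

import numpy as np

def dequantize_q2_k_block_ref(block: bytes) -> np.ndarray:
    # One 84-byte Q2_K block -> 256 float weights, mirroring the kernel's group order.
    scales = np.frombuffer(block, dtype=np.uint8, count=16, offset=0)
    qs = np.frombuffer(block, dtype=np.uint8, count=64, offset=16)
    d, dmin = np.frombuffer(block, dtype=np.float16, count=2, offset=80).astype(np.float32)
    out = np.empty(256, dtype=np.float32)
    for g in range(16):                              # 16 groups of 16 weights, one scale byte each
        half, j, sub = g // 8, (g % 8) // 2, g % 2   # which 32-byte half of qs, bit shift, sub-row
        sc = int(scales[g])
        dl, ml = d * (sc & 0xF), dmin * (sc >> 4)
        q = (qs[32 * half + 16 * sub : 32 * half + 16 * sub + 16] >> (2 * j)) & 3
        out[16 * g : 16 * g + 16] = dl * q - ml
    return out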
+__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+    for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+
+        uint32_t aux[4];
+        const int8_t * scales = (const int8_t*)aux;
+        const float d_all = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 108)));
+
+        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
+        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
+        uint8_t m = 1;
+
+        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
+
+        for (int i = 0; i < 3; i++) {
+            aux[i] = 0;
+            for (int j = 0; j < 4; j++) {
+                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
+            }
+        }
+
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+    }
+}
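The fiddliest part of the Q3_K kernel is rebuilding its sixteen 6-bit scales from the 12 packed bytes at offset 96: the low 4 bits and high 2 bits of each scale are stored separately, and the kmask1/kmask2 shuffle stitches them back together before the bias of 32 is removed. A NumPy sketch of just that step, assuming a little-endian layout exactly as the kernel's byte-wise loads and (int8_t*)aux reinterpretation do (helper name illustrative):

import struct
import numpy as np

def unpack_q3_k_scales_ref(scale_bytes: bytes) -> np.ndarray:
    # 12 packed bytes -> 16 signed 6-bit scales, mirroring the kernel's kmask shuffle.
    kmask1, kmask2 = 0x03030303, 0x0F0F0F0F
    a0, a1, tmp = struct.unpack("<3I", scale_bytes[:12])   # little-endian 32-bit loads, as aux[] is built
    aux = [
        (a0 & kmask2) | (((tmp >> 0) & kmask1) << 4),
        (a1 & kmask2) | (((tmp >> 2) & kmask1) << 4),
        ((a0 >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4),
        ((a1 >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4),
    ]
    packed = struct.pack("<4I", *aux)                      # same bytes the kernel reads through (int8_t*)aux
    return np.frombuffer(packed, dtype=np.int8).astype(np.int32) - 32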
 
 __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
@@ -176,6 +267,24 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
     return output;
 }
 
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
+
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
@@ -196,8 +305,7 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
     return output;
 }
 
-
-torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
     int num_blocks = data.numel() / blk_size;
 
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
@@ -209,7 +317,25 @@ torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device de
     auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
 
     // Launch kernel
-    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+    dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
 
     cudaDeviceSynchronize();
     return output;
 }
+
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
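Each host wrapper follows the same pattern: copy the raw block bytes into an int8 tensor on the target device, allocate a (num_blocks, 256) float32 output, launch the kernel on a fixed <<<512, 256>>> grid, and synchronize. A quick consistency check against the existing CPU dequantizer, sketched under the assumption that the Python helpers live in ktransformers.util.custom_gguf and that raw_bytes holds one tensor's Q2_K payload:

import numpy as np
from ktransformers.util.custom_gguf import dequantize_q2_k, dequantize_q2_k_gpu  # assumed import path

data = np.frombuffer(raw_bytes, dtype=np.uint8)               # raw_bytes: hypothetical Q2_K payload
cpu = dequantize_q2_k(data).reshape(-1, 256)                  # existing NumPy path
gpu = dequantize_q2_k_gpu(data, device="cuda").cpu().numpy()  # new GPU path
assert np.allclose(cpu, gpu, atol=1e-3), "GPU Q2_K dequantize disagrees with the CPU reference"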
@@ -4,7 +4,7 @@
 * @Date : 2024-07-22 09:27:55
 * @Version : 1.0.0
 * @LastEditors : kkk1nak0
-* @LastEditTime : 2024-08-09 01:44:21
+* @LastEditTime : 2024-08-12 03:48:46
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #pragma once
@@ -17,3 +17,5 @@ torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device de
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
 Date : 2024-07-26 08:48:54
 Version : 1.0.0
 LastEditors : kkk1nak0
-LastEditTime : 2024-08-09 08:03:44
+LastEditTime : 2024-08-12 07:21:55
 Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
 Copyright (c) 2023-2024 The ggml authors
 Copyright (c) 2024 Thomas Germer
@@ -390,8 +390,14 @@ def dequantize_q2_k(data):
 
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
 
-def dequantize_q2_k_gpu(data):
-    raise NotImplementedError()
+def dequantize_q2_k_gpu(data, device:str ="cuda"):
+    block_size = GGML_BLOCK_SIZES["Q2_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
+    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q2_k(data, block_size, device)
 
 def dequantize_q3_k(data):
     # C implementation
@@ -435,8 +441,14 @@ def dequantize_q3_k(data):
         (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
     ], axis=1)
 
-def dequantize_q3_k_gpu(data):
-    raise NotImplementedError()
+def dequantize_q3_k_gpu(data, device:str ="cuda"):
+    block_size = GGML_BLOCK_SIZES["Q3_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
+    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q3_k(data, block_size, device)
 
 def dequantize_q4_k(data):
     # C implementation
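With both _gpu variants filled in, a loader can pick the GPU path per GGML quantization type. The table below is purely illustrative (its name and the fallback behaviour are assumptions, not taken from this diff):

# Hypothetical dispatch table; names are illustrative, not from this commit.
DEQUANTIZE_GPU = {
    "Q2_K": dequantize_q2_k_gpu,
    "Q3_K": dequantize_q3_k_gpu,
}

def dequantize_on_gpu(ggml_type: str, raw: bytes, device: str = "cuda"):
    fn = DEQUANTIZE_GPU.get(ggml_type)
    if fn is None:
        raise NotImplementedError(f"no GPU dequantizer wired up for {ggml_type}")
    return fn(raw, device=device)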