mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-01 21:21:12 +00:00
Support IQ4_XS dequantize
This commit is contained in:
parent
022b893819
commit
be356c1b8d
5 changed files with 93 additions and 2 deletions
|
|
@ -31,6 +31,8 @@ PYBIND11_MODULE(KTransformersOps, m) {
|
|||
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
||||
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
|
||||
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
||||
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
|
||||
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
||||
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
|
||||
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
|
||||
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ PYBIND11_MODULE(cudaops, m) {
|
|||
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
||||
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
|
||||
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
||||
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
|
||||
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
||||
m.def("test", &test, "Function to test.");
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -212,6 +212,29 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
|
|||
}
|
||||
}
|
||||
|
||||
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
|
||||
__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
|
||||
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
|
||||
float* __restrict__ output_blk = (float*)(output + block_id * 256);
|
||||
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
|
||||
const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
|
||||
const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
|
||||
const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
|
||||
|
||||
for (int ib = 0; ib < 8; ++ib) {
|
||||
const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);
|
||||
const float dl = d * (ls - 32);
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
output_blk[j + 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
|
||||
output_blk[j + 16] = dl * kvalues_iq4nl[qs[j] >> 4];
|
||||
}
|
||||
output_blk += 32;
|
||||
qs += 16;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
|
||||
int num_blocks = data.numel() / blk_size;
|
||||
|
|
@ -339,4 +362,22 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
|
|||
|
||||
cudaDeviceSynchronize();
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) {
|
||||
int num_blocks = data.numel() / blk_size;
|
||||
|
||||
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
|
||||
auto data_gpu = torch::empty({data.numel()}, options);
|
||||
|
||||
data_gpu.copy_(data, false);
|
||||
|
||||
// Create output tensor
|
||||
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
|
||||
|
||||
// Launch kernel
|
||||
dequantize_iq4_xs_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
return output;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,4 +18,5 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
|
|||
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
|
||||
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
|
||||
torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
|
||||
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
|
||||
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
|
||||
torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue