diff --git a/ktransformers/ktransformers_ext/cuda/binding.cpp b/ktransformers/ktransformers_ext/cuda/binding.cpp
index 06ec5f3..65c8bc4 100644
--- a/ktransformers/ktransformers_ext/cuda/binding.cpp
+++ b/ktransformers/ktransformers_ext/cuda/binding.cpp
@@ -31,6 +31,8 @@ PYBIND11_MODULE(KTransformersOps, m) {
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
           py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
           py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
index 70fc606..99069d8 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
@@ -28,6 +28,8 @@ PYBIND11_MODULE(cudaops, m) {
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("test", &test, "Function to test.");
 }
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
index cc5552b..1583cf7 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
@@ -212,6 +212,29 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
     }
 }
 
+static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
+        const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
+        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
+        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
+
+        for (int ib = 0; ib < 8; ++ib) {
+            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);
+            const float dl = d * (ls - 32);
+            for (int j = 0; j < 16; ++j) {
+                output_blk[j + 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
+                output_blk[j + 16] = dl * kvalues_iq4nl[qs[j] >> 4];
+            }
+            output_blk += 32;
+            qs += 16;
+        }
+    }
+}
 
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
     int num_blocks = data.numel() / blk_size;
@@ -339,4 +362,22 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
 
     cudaDeviceSynchronize();
     return output;
-}
\ No newline at end of file
+}
+
+torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_iq4_xs_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
index 5196f88..666d455 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
@@ -18,4 +18,5 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
 torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);
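For reference, each IQ4_XS superblock decoded by the kernel above packs 256 weights into 136 bytes: an fp16 super-scale d (bytes 0-1), a uint16 of high scale bits scales_h (bytes 2-3), four bytes of low scale bits scales_l (bytes 4-7), and 128 bytes of packed 4-bit indices qs. Each of the eight 32-element groups gets a 6-bit scale, rebuilt from 4 low bits and 2 high bits and biased by -32; the indices then select entries from the non-uniform kvalues_iq4nl table. A minimal NumPy sketch of the scale decode (a hypothetical standalone helper mirroring the kernel's `ls` expression, not part of this patch):

import numpy as np

# Hypothetical helper, not part of the patch: decode the 6-bit scale of group ib (0..7).
def decode_iq4_xs_scale(scales_l: np.ndarray, scales_h: int, ib: int) -> int:
    lo = (int(scales_l[ib // 2]) >> (4 * (ib % 2))) & 0xF  # 4 low bits; two groups per byte
    hi = (scales_h >> (2 * ib)) & 0x3                      # 2 high bits from the packed uint16
    return (lo | (hi << 4)) - 32                           # signed 6-bit scale in [-32, 31]

# Example: bytes 4..7 of a block as scales_l, bytes 2..3 (little-endian) as scales_h.
scales_l = np.array([0x51, 0x23, 0x67, 0x89], dtype=np.uint8)
scales_h = 0b01_10_00_11_01_10_00_01
print([decode_iq4_xs_scale(scales_l, scales_h, ib) for ib in range(8)])
# ib == 0: lo = 1, hi = 1, so (1 | 16) - 32 = -15

The dequantize_iq4_xs added to custom_gguf.py below performs this same decode, vectorized over all blocks.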
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 04ce0ae..a90c0ed 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -108,6 +108,7 @@ GGML_TYPES = {
     "Q4_K": 12,
     "Q5_K": 13,
     "Q6_K": 14,
+    "IQ4_XS": 23,
 }
 
 GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
@@ -123,6 +124,7 @@ GGML_BLOCK_SIZES = {
     "Q4_K": 2 + 2 + 12 + 256 // 2,
     "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
     "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
+    "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
 }
 
 GGML_ELEMENTS_PER_BLOCK = {
@@ -136,6 +138,7 @@ GGML_ELEMENTS_PER_BLOCK = {
     "Q4_K": 256,
     "Q5_K": 256,
     "Q6_K": 256,
+    "IQ4_XS": 256,
 }
 
 DATA_TYPES = {
@@ -601,6 +604,46 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
     data = torch.from_numpy(data)
     return KTransformersOps.dequantize_q6_k(data, block_size, device)
 
+kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
+
+def dequantize_iq4_xs(data):
+    # C implementation
+    # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-quants.c#L3568
+    # C struct definition
+    # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-common.h#L393
+    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    num_blocks = len(data) // block_size
+
+    d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)
+    scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)
+    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]
+    scales_l = data_u8[:, :4].reshape(num_blocks, 4)
+    qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)
+
+    ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)
+    for ib in range(QK_K // 32):
+        ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)
+
+    dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)
+
+    qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf
+    qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4
+
+    y = np.zeros((num_blocks, QK_K), dtype=np.float32)
+    for ib in range(QK_K // 32):
+        y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]
+        y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]
+
+    return y.flatten()
+
+def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"):
+    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    device = torch.device(device)
+    num_blocks = len(data) // block_size
+    data = np.frombuffer(data, dtype=data.dtype)
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_iq4_xs(data, block_size, device)
+
 def dequantize_q4_0(data):
     # C implementation
     # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515
@@ -693,6 +736,7 @@ GGML_DEQUANTIZE = {
     "Q4_K": dequantize_q4_k,
     "Q5_K": dequantize_q5_k,
     "Q6_K": dequantize_q6_k,
+    "IQ4_XS": dequantize_iq4_xs,
 }
 
 GGML_DEQUANTIZE_GPU = {
@@ -706,6 +750,7 @@ GGML_DEQUANTIZE_GPU = {
     "Q4_K": dequantize_q4_k_gpu,
     "Q5_K": dequantize_q5_k_gpu,
     "Q6_K": dequantize_q6_k_gpu,
+    "IQ4_XS": dequantize_iq4_xs_gpu,
 }
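With the patch applied and the KTransformersOps extension rebuilt, the new CPU and GPU paths can be cross-checked against each other. A minimal sketch, assuming a CUDA device is available; the synthetic blocks are hypothetical stand-ins for tensor data read from a real GGUF file:

import numpy as np
from ktransformers.util.custom_gguf import (
    GGML_BLOCK_SIZES, dequantize_iq4_xs, dequantize_iq4_xs_gpu,
)

rng = np.random.default_rng(0)

def make_iq4_xs_block() -> bytes:
    # Layout: fp16 d | uint16 scales_h | 4x uint8 scales_l | 128x uint8 qs = 136 bytes.
    d = np.float16(0.5).tobytes()  # fixed super-scale; random fp16 bytes could decode to NaN/inf
    scales_h = rng.integers(0, 2**16, size=1, dtype=np.uint16).tobytes()
    scales_l = rng.integers(0, 256, size=4, dtype=np.uint8).tobytes()
    qs = rng.integers(0, 256, size=128, dtype=np.uint8).tobytes()
    return d + scales_h + scales_l + qs

num_blocks = 4
raw = b"".join(make_iq4_xs_block() for _ in range(num_blocks))
assert len(raw) == num_blocks * GGML_BLOCK_SIZES["IQ4_XS"]  # 4 * 136

cpu = dequantize_iq4_xs(raw)  # NumPy reference, shape (num_blocks * 256,)
gpu = dequantize_iq4_xs_gpu(np.frombuffer(raw, dtype=np.uint8).copy())  # CUDA kernel, shape (num_blocks, 256)

assert np.allclose(cpu.reshape(num_blocks, 256), gpu.cpu().numpy(), atol=1e-6)

Both paths perform the same decode, so the results should agree to within float32 rounding.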