Support IQ4_XS dequantize

Author: Yap Sok Ann
Date:   2024-09-02 09:03:42 +07:00
Commit: be356c1b8d
Parent: 022b893819

5 changed files with 93 additions and 2 deletions


@@ -31,6 +31,8 @@ PYBIND11_MODULE(KTransformersOps, m) {
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
   m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
+  m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
   m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
         py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
         py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),


@@ -28,6 +28,8 @@ PYBIND11_MODULE(cudaops, m) {
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
   m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
+  m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
+        py::arg("data"), py::arg("blk_size"), py::arg("device"));
   m.def("test", &test, "Function to test.");
 }
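
Once either extension module is rebuilt, the new binding is callable like the existing dequantize entry points. A minimal smoke test, assuming the module builds as KTransformersOps and feeding zeroed placeholder bytes rather than real GGUF data:

import torch
import KTransformersOps

blk_size = 136  # bytes per IQ4_XS super-block: 2 + 2 + 4 + 256 // 2
data = torch.zeros(blk_size, dtype=torch.int8)  # placeholder, not real quantized data
out = KTransformersOps.dequantize_iq4_xs(data, blk_size, torch.device("cuda:0"))
print(out.shape)  # torch.Size([1, 256]): one dequantized super-block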


@@ -212,6 +212,29 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
     }
 }
 
+static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
+        const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
+        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
+        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
+        for (int ib = 0; ib < 8; ++ib) {
+            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);
+            const float dl = d * (ls - 32);
+            for (int j = 0; j < 16; ++j) {
+                output_blk[j + 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
+                output_blk[j + 16] = dl * kvalues_iq4nl[qs[j] >> 4];
+            }
+            output_blk += 32;
+            qs += 16;
+        }
+    }
+}
+
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
     int num_blocks = data.numel() / blk_size;
@@ -339,4 +362,22 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
     cudaDeviceSynchronize();
     return output;
 }
+
+torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_iq4_xs_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
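
For reference, the offsets hard-coded above (2, 2 + 2, and 2 + 2 + 4) fall out of the IQ4_XS block layout. A quick sketch of the byte budget per 256-element super-block, matching the GGML_BLOCK_SIZES entry added further down:

# IQ4_XS super-block: 256 weights split into 8 sub-blocks of 32.
d_bytes        = 2          # fp16 super-block scale
scales_h_bytes = 2          # 2 high scale bits per sub-block x 8 sub-blocks
scales_l_bytes = 256 // 64  # 4 low scale bits per sub-block -> 4 bytes
qs_bytes       = 256 // 2   # one 4-bit codebook index per weight -> 128 bytes
assert d_bytes + scales_h_bytes + scales_l_bytes + qs_bytes == 136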


@@ -18,4 +18,5 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
 torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);


@@ -108,6 +108,7 @@ GGML_TYPES = {
     "Q4_K": 12,
     "Q5_K": 13,
     "Q6_K": 14,
+    "IQ4_XS": 23,
 }
 
 GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
@@ -123,6 +124,7 @@ GGML_BLOCK_SIZES = {
     "Q4_K": 2 + 2 + 12 + 256 // 2,
     "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
     "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
+    "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
 }
 
 GGML_ELEMENTS_PER_BLOCK = {
@@ -136,6 +138,7 @@ GGML_ELEMENTS_PER_BLOCK = {
     "Q4_K": 256,
     "Q5_K": 256,
     "Q6_K": 256,
+    "IQ4_XS": 256,
 }
 
 DATA_TYPES = {
@@ -601,6 +604,46 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
     data = torch.from_numpy(data)
     return KTransformersOps.dequantize_q6_k(data, block_size, device)
 
+kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
+
+def dequantize_iq4_xs(data):
+    # C implementation
+    # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-quants.c#L3568
+    # C struct definition
+    # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-common.h#L393
+    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    num_blocks = len(data) // block_size
+
+    d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)
+    scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)
+    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]
+    scales_l = data_u8[:, :4].reshape(num_blocks, 4)
+    qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)
+
+    ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)
+    for ib in range(QK_K // 32):
+        ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)
+
+    dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)
+
+    qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf
+    qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4
+
+    y = np.zeros((num_blocks, QK_K), dtype=np.float32)
+    for ib in range(QK_K // 32):
+        y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]
+        y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]
+
+    return y.flatten()
+
+def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"):
+    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    device = torch.device(device)
+    num_blocks = len(data) // block_size
+    data = np.frombuffer(data, dtype=data.dtype)
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_iq4_xs(data, block_size, device)
+
 def dequantize_q4_0(data):
     # C implementation
     # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515
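
The NumPy and CUDA paths added above are meant to agree up to float rounding. A hedged cross-check, where the file name is a hypothetical stand-in for raw IQ4_XS tensor bytes extracted from a GGUF file:

import numpy as np

raw = np.fromfile("iq4_xs_tensor.bin", dtype=np.int8)     # hypothetical raw block data
ref = dequantize_iq4_xs(raw.tobytes())                    # NumPy reference, flat float32
gpu = dequantize_iq4_xs_gpu(raw).cpu().numpy().flatten()  # CUDA path, same layout
assert np.allclose(ref, gpu, rtol=1e-3, atol=1e-5)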
@@ -693,6 +736,7 @@ GGML_DEQUANTIZE = {
     "Q4_K": dequantize_q4_k,
     "Q5_K": dequantize_q5_k,
     "Q6_K": dequantize_q6_k,
+    "IQ4_XS": dequantize_iq4_xs,
 }
 
 GGML_DEQUANTIZE_GPU = {
@@ -706,6 +750,7 @@ GGML_DEQUANTIZE_GPU = {
     "Q4_K": dequantize_q4_k_gpu,
     "Q5_K": dequantize_q5_k_gpu,
     "Q6_K": dequantize_q6_k_gpu,
+    "IQ4_XS": dequantize_iq4_xs_gpu,
 }
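
With both dispatch tables extended, loader code can stay format-agnostic. A minimal sketch of how the tables might be consumed; the dequantize helper itself is an illustration, not part of this commit:

def dequantize(raw: bytes, ggml_type: int, device: str = "cpu"):
    # Map the GGUF type id (23 for IQ4_XS) back to a format name, then dispatch.
    name = GGML_NAMES[ggml_type]
    if device == "cpu":
        return GGML_DEQUANTIZE[name](raw)
    return GGML_DEQUANTIZE_GPU[name](np.frombuffer(raw, dtype=np.int8), device)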