optimize GPU

This commit is contained in:
Atream 2025-02-21 05:06:57 +00:00
parent cf4da5fd47
commit 7e1fe256c8
8 changed files with 677 additions and 156 deletions

View file

@@ -20,19 +20,19 @@
PYBIND11_MODULE(KTransformersOps, m) {
    m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
        py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
        py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),

View file

@@ -17,19 +17,19 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
PYBIND11_MODULE(cudaops, m) {
    m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-       py::arg("data"), py::arg("blk_size"), py::arg("device"));
        py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
    m.def("test", &test, "Function to test.");
}
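For reference, a minimal sketch of how the re-bound dequantize functions are expected to be called from Python after this change. The raw-buffer/byte-count calling convention mirrors the GGUF loader changes later in this commit; the zeroed placeholder data and the 84-byte Q2_K block size are illustration-only assumptions, not part of the commit.

import numpy as np
import torch
import KTransformersOps

data = np.zeros(84 * 4, dtype=np.int8)   # placeholder: 4 zeroed Q2_K blocks (84 bytes each)
out = KTransformersOps.dequantize_q2_k(
    data.data,               # raw buffer instead of a torch.Tensor
    data.size,               # num_bytes
    84,                      # blk_size
    torch.device("cuda:0"),
    torch.float16,           # target_dtype: float16, bfloat16 or float32
)
# out: a (num_blocks, 256) CUDA tensor in the requested dtype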

View file

@@ -10,19 +10,53 @@
 * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*/
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/library.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <cstdint>
#include <c10/cuda/CUDAGuard.h>
-__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-   for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
-       for(int i=0;i<blk_size;i++){
-           float scale = scales[block_id];
-           output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];
        float* __restrict__ output_blk = (float*)(output + block_id * 256);
        const int8_t* cur_block = data + block_id * blk_size;
        float scale = __half2float(*((half*)cur_block));
        cur_block += 2;
        for (int i = 0; i < 32; i++){
            output_blk[i] = scale * cur_block[i];
        }
        output_blk += 32;
    }
}
__global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
__half* __restrict__ output_blk = (__half*)(output + block_id * 256);
const int8_t* cur_block = data + block_id * blk_size;
float scale = __half2float(*((half*)cur_block));
cur_block += 2;
for (int i = 0; i < 32; i++) {
output_blk[i] = __float2half(scale * cur_block[i]);
}
output_blk += 32;
}
}
__global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256);
const int8_t* cur_block = data + block_id * blk_size;
float scale = __half2float(*((half*)cur_block));
cur_block += 2;
for (int i = 0; i < 32; i++) {
output_blk[i] = __float2bfloat16(scale * cur_block[i]);
}
output_blk += 32;
    }
}
@@ -36,13 +70,13 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
    }
}
-__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
__global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
        float* __restrict__ output_blk = (float*)(output + block_id * 256);
-       const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
-       const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 82)));
        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
@@ -70,7 +104,75 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size
    }
}
-__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
int is = 0;
float dl, ml;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
uint8_t sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);
scales = (uint8_t*)(data + block_id * blk_size + (is++));
sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);
shift += 2;
}
q += 32;
}
}
}
__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
int is = 0;
float dl, ml;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
uint8_t sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);
scales = (uint8_t*)(data + block_id * blk_size + (is++));
sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);
shift += 2;
}
q += 32;
}
}
}
__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const uint32_t kmask1 = 0x03030303;
@@ -80,7 +182,7 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size
        uint32_t aux[4];
        const int8_t * scales = (const int8_t*)aux;
-       const float d_all = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 108)));
        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));
        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
@@ -126,16 +228,128 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size
    }
}
-__global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
__global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256);
uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux;
const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
uint8_t m = 1;
uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
for (int i = 0; i < 3; i++) {
aux[i] = 0;
for (int j = 0; j < 4; j++) {
aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
}
}
uint32_t tmp = aux[2];
aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
int is = 0;
float dl;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = __float2half(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)));
}
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));
}
shift += 2;
m <<= 1;
}
q += 32;
}
}
}
__global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256);
uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux;
const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
uint8_t m = 1;
uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
for (int i = 0; i < 3; i++) {
aux[i] = 0;
for (int j = 0; j < 4; j++) {
aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
}
}
uint32_t tmp = aux[2];
aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
int is = 0;
float dl;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)));
}
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));
}
shift += 2;
m <<= 1;
}
q += 32;
}
}
}
__global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
        float* __restrict__ output_blk = (float*)(output + block_id * 256);
        // const uint8_t * q = data[i].qs;
        const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
-       const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * 144 + 0)));
-       const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * 144 + 2)));
        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));
        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
        int is = 0;
        uint8_t sc, m;
        for (int j = 0; j < blk_size; j += 64) {
@@ -151,13 +365,61 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
    }
}
-__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256);
// const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < blk_size; j += 64) {
uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * (q[l] & 0xF) - m1);
for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * (q[l] >> 4) - m2);
q += 32; is += 2;
}
}
}
__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256);
// const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < blk_size; j += 64) {
uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * (q[l] & 0xF) - m1);
for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * (q[l] >> 4) - m2);
q += 32; is += 2;
}
}
}
__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
        float* __restrict__ output_blk = (float*)(output + block_id * 256);
-       const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
-       const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 2)));
        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
@@ -180,11 +442,69 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
    }
}
-__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
for (int j = 0; j < 256; j += 64) {
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);
for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2);
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
}
}
__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
for (int j = 0; j < 256; j += 64) {
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);
for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2);
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
}
}
__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
        float* __restrict__ output_blk = (float*)(output + block_id * 256);
-       const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208)));
        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);
@@ -212,14 +532,78 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
    }
}
__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);
const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);
//if (blk_size == 256){
for (int n = 0; n < blk_size; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
output_blk[l + 0] = __float2half(d * sc[is + 0] * q1);
output_blk[l + 32] = __float2half(d * sc[is + 2] * q2);
output_blk[l + 64] = __float2half(d * sc[is + 4] * q3);
output_blk[l + 96] = __float2half(d * sc[is + 6] * q4);
}
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);
const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);
//if (blk_size == 256){
for (int n = 0; n < blk_size; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
output_blk[l + 0] = __float2bfloat16(d * sc[is + 0] * q1);
output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2);
output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3);
output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4);
}
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
        float* __restrict__ output_blk = (float*)(output + block_id * 256);
-       const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
-       const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
@@ -236,152 +620,267 @@ __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_si
    }
}
-torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
-    int num_blocks = data.numel() / blk_size;
__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
__half* __restrict__ output_blk = (__half*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
for (int ib = 0; ib < 8; ++ib) {
const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);
const float dl = d * (ls - 32);
for (int j = 0; j < 16; ++j) {
output_blk[j + 0] = __float2half(dl * kvalues_iq4nl[qs[j] & 0xf]);
output_blk[j + 16] = __float2half(dl * kvalues_iq4nl[qs[j] >> 4]);
}
output_blk += 32;
qs += 16;
}
}
}
__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
for (int ib = 0; ib < 8; ++ib) {
const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);
const float dl = d * (ls - 32);
for (int j = 0; j < 16; ++j) {
output_blk[j + 0] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] & 0xf]);
output_blk[j + 16] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] >> 4]);
}
output_blk += 32;
qs += 16;
}
}
}
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) {
    int num_blocks = num_bytes / blk_size;
    const at::cuda::OptionalCUDAGuard device_guard(device);
-   // create gpu
-   auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous);
-   auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
-   auto scales_gpu = torch::empty({{num_blocks, 1}}, options_scales);
-   auto qs_gpu = torch::empty({num_blocks, 32}, options_qs);
-   // read on cpu
-   options_scales = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCPU);
-   options_qs = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
-   // // reinterpret
-   auto scales = torch::from_blob(data.data_ptr(), {num_blocks, 1 + 16}, options_scales).slice(1, 0, 1);
-   auto qs = torch::from_blob(data.data_ptr(), {num_blocks, 2 + 32}, options_qs).slice(1, 2);
-   auto scales_f32 = scales.to(torch::kFloat32);
-   scales_gpu.copy_(scales_f32, false);
-   qs_gpu.copy_(qs, false);
    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
    auto data_gpu = torch::empty({ num_bytes }, options);
    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
    //data_gpu.copy_(data, false);
    // Create output tensor
-   auto output = torch::zeros_like(qs, torch::dtype(torch::kFloat32).device(device));
    auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device));
-   // Launch kernel
-   dequantize_q8_0_kernel<<< 512, 256 >>>(
-       output.data_ptr<float>(), scales_gpu.data_ptr<float>(), qs_gpu.data_ptr<int8_t>(), num_blocks, 32);
    switch (target_dtype) {
        case torch::kFloat16:
            dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks);
            break;
        case torch::kBFloat16:
            dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks);
            break;
        case torch::kFloat32:
            dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
            break;
        default:
            printf("target type not support\n");
            exit(0);
    }
    cudaDeviceSynchronize();
    return output;
}
-torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device) {
torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) {
    // data.numel%blk_size should be 0, else raise err
-   int num_blocks = data.numel() / blk_size;
    int num_blocks = num_bytes / blk_size;
    const at::cuda::OptionalCUDAGuard device_guard(device);
    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
-   auto data_gpu = torch::empty({data.numel()}, options);
-   data_gpu.copy_(data, false);
    auto data_gpu = torch::empty({num_bytes}, options);
    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
    //data_gpu.copy_(data, false);
    // Create output tensor
-   auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
-   // Launch kernel
-   dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
-   // dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), 256, num_blocks);
    switch (target_dtype) {
        case torch::kFloat16:
            dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks);
            break;
        case torch::kBFloat16:
            dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks);
            break;
        case torch::kFloat32:
            dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
            break;
        default:
            printf("target type not support\n");
            exit(0);
    }
    cudaDeviceSynchronize();
    return output;
}
-torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
-    int num_blocks = data.numel() / blk_size;
torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) {
    int num_blocks = num_bytes / blk_size;
    const at::cuda::OptionalCUDAGuard device_guard(device);
    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
-   auto data_gpu = torch::empty({data.numel()}, options);
-   data_gpu.copy_(data, false);
    auto data_gpu = torch::empty({num_bytes}, options);
    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
    //data_gpu.copy_(data, false);
    // Create output tensor
-   auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
-   // Launch kernel
-   dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
    switch (target_dtype) {
        case torch::kFloat16:
            dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks);
            break;
        case torch::kBFloat16:
            dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks);
            break;
        case torch::kFloat32:
            dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
            break;
        default:
            printf("target type not support\n");
            exit(0);
    }
    cudaDeviceSynchronize();
    return output;
}
-torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) {
    // data.numel%blk_size should be 0, else raise err
-   int num_blocks = data.numel() / blk_size;
    int num_blocks = num_bytes / blk_size;
    const at::cuda::OptionalCUDAGuard device_guard(device);
    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
-   auto data_gpu = torch::empty({data.numel()}, options);
-   data_gpu.copy_(data, false);
    auto data_gpu = torch::empty({num_bytes}, options);
    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
    //data_gpu.copy_(data, false);
    // Create output tensor
-   auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
-   // Launch kernel
-   dequantize_q4_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), 256, num_blocks);
    switch (target_dtype) {
        case torch::kFloat16:
            dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks);
            break;
        case torch::kBFloat16:
            dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks);
            break;
        case torch::kFloat32:
            dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
            break;
        default:
            printf("target type not support\n");
            exit(0);
    }
    cudaDeviceSynchronize();
    return output;
}
-torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
-    int num_blocks = data.numel() / blk_size;
torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) {
    int num_blocks = num_bytes / blk_size;
    const at::cuda::OptionalCUDAGuard device_guard(device);
    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
-   auto data_gpu = torch::empty({data.numel()}, options);
-   data_gpu.copy_(data, false);
    auto data_gpu = torch::empty({num_bytes}, options);
    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
    //data_gpu.copy_(data, false);
    // Create output tensor
-   auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
-   // Launch kernel
-   dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
    switch (target_dtype) {
        case torch::kFloat16:
            dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks);
            break;
        case torch::kBFloat16:
            dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks);
            break;
        case torch::kFloat32:
            dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
            break;
        default:
            printf("target type not support\n");
            exit(0);
    }
    cudaDeviceSynchronize();
    return output;
}
-torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
-    int num_blocks = data.numel() / blk_size;
torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) {
    int num_blocks = num_bytes / blk_size;
    const at::cuda::OptionalCUDAGuard device_guard(device);
    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
-   auto data_gpu = torch::empty({data.numel()}, options);
-   data_gpu.copy_(data, false);
    auto data_gpu = torch::empty({num_bytes}, options);
    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
    //data_gpu.copy_(data, false);
    // Create output tensor
-   auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
-   // Launch kernel
-   dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
    switch (target_dtype) {
        case torch::kFloat16:
            dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks);
            break;
        case torch::kBFloat16:
            dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks);
            break;
        case torch::kFloat32:
            dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
            break;
        default:
            printf("target type not support\n");
            exit(0);
    }
    cudaDeviceSynchronize();
    return output;
}
-torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) {
-    int num_blocks = data.numel() / blk_size;
torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) {
    int num_blocks = num_bytes / blk_size;
    const at::cuda::OptionalCUDAGuard device_guard(device);
    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
-   auto data_gpu = torch::empty({data.numel()}, options);
-   data_gpu.copy_(data, false);
    auto data_gpu = torch::empty({num_bytes}, options);
    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
    //data_gpu.copy_(data, false);
    // Create output tensor
-   auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
-   // Launch kernel
-   dequantize_iq4_xs_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
    switch (target_dtype) {
        case torch::kFloat16:
            dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks);
            break;
        case torch::kBFloat16:
            dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks);
            break;
        case torch::kFloat32:
            dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
            break;
        default:
            printf("target type not support\n");
            exit(0);
    }
    cudaDeviceSynchronize();
    return output;
}

View file

@@ -13,10 +13,10 @@
#include <torch/extension.h>
#include <torch/torch.h>
-torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);

View file

@@ -168,9 +168,6 @@ def local_chat(
    if mode == 'long_context':
        assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
        "please change max_seq_len in ~/.ktransformers/config.yaml"
-   torch.set_default_dtype(
-       torch.bfloat16
-   ) # TODO: Remove this, replace dtype using config
    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
        generated = prefill_and_generate(

View file

@@ -5,6 +5,18 @@
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^lm_head$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously

View file

@@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext):
class KTransformersInterface(TransformersInterface):
    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
-       torch.set_default_dtype(torch.bfloat16)
        torch.set_grad_enabled(False)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
        torch.set_default_dtype(config.torch_dtype)
        if config.architectures[0] == "Qwen2MoeForCausalLM":
            config._attn_implementation = "flash_attention_2"

View file

@@ -285,7 +285,7 @@ class GGUFLoader:
        itemsize = int(np.empty([], dtype = item_type).itemsize)
        return mmap_data[offset : offset + itemsize * item_count]

-   def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":
            print(f"loading expert {expert_id} of {name} with CPU")
@@ -304,7 +304,7 @@ class GGUFLoader:
        data = data[offset: offset + block_size * blocks_per_experts]

        if "cuda" in device.lower():
-           values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
        else:
            values = GGML_DEQUANTIZE[ggml_name](data)
            values = torch.from_numpy(values)
@@ -313,7 +313,7 @@ class GGUFLoader:
        return values

-   def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
    def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = torch.get_default_dtype())->torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":
            print(f"loading {name} with CPU")
@@ -328,16 +328,36 @@ class GGUFLoader:
        data = self.get_mmap_tensor(name)

-       if "cuda" in device.lower():
-           values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
-           #values = GGML_DEQUANTIZE[ggml_name](data)
-           #print("load_gguf_tensor")
-           #values = torch.from_numpy(values).to(device = device)
        block_size = GGML_BLOCK_SIZES[ggml_name]
        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
        num_elements = int(np.prod(shape))
        num_blocks = num_elements // elements_per_block

        blocks_per_iter = 16384
        if num_blocks > blocks_per_iter: # dequant large tensor
            values = torch.empty((num_blocks, elements_per_block), dtype=torch.float, device=device)
            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
                blocks_begin = i * blocks_per_iter
                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
                if "cuda" in device.lower():
                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
                else:
                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
                    cur_values = torch.from_numpy(cur_values)
                cur_values = cur_values.view(-1, elements_per_block)
                values[blocks_begin : blocks_end] = cur_values
        else:
-           values = GGML_DEQUANTIZE[ggml_name](data)
-           values = torch.from_numpy(values)
            if "cuda" in device.lower():
                values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
            else:
                values = GGML_DEQUANTIZE[ggml_name](data)
                values = torch.from_numpy(values)

        if ggml_name == "BF16":
            values = values.view(torch.bfloat16)
        values = values.view(shape[::-1])

        if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
            n_head = self.gguf_file_meta['llama.attention.head_count']
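As a usage sketch, the updated loader entry point can now be asked for a specific dequantization dtype directly; the file path and tensor name below are hypothetical placeholders, not taken from the commit.

loader = GGUFLoader("/path/to/gguf_dir")   # hypothetical path
w = loader.load_gguf_tensor("blk.0.ffn_down.weight", device="cuda:0", target_dtype=torch.float16)
# the GPU kernels dequantize into target_dtype; large tensors are processed in 16384-block chunks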
@@ -433,14 +453,13 @@ def dequantize_q2_k(data):
    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)

-def dequantize_q2_k_gpu(data, device:str ="cuda"):
def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
    block_size = GGML_BLOCK_SIZES["Q2_K"]
    data = np.frombuffer(data, dtype=data.dtype)
    device = torch.device(device)
    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-   data = torch.from_numpy(data)
-   return KTransformersOps.dequantize_q2_k(data, block_size, device)
    return KTransformersOps.dequantize_q2_k(data.data, data.size, block_size, device, target_dtype)
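The same target_dtype also threads through the per-type dispatch table used by the loader above. A minimal sketch, assuming the table is keyed by the GGML type name (as in the GGML_BLOCK_SIZES lookups) and using zeroed placeholder data:

raw = np.zeros(84 * 8, dtype=np.int8)    # placeholder: 8 zeroed Q2_K blocks (84 bytes each, assumed layout)
values = GGML_DEQUANTIZE_GPU["Q2_K"](raw, "cuda:0", torch.float16)
# equivalent to calling dequantize_q2_k_gpu(raw, "cuda:0", torch.float16) shown above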
def dequantize_q3_k(data): def dequantize_q3_k(data):
# C implementation # C implementation
@@ -484,14 +503,13 @@ def dequantize_q3_k(data):
         (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
     ], axis=1)
 
-def dequantize_q3_k_gpu(data, device:str ="cuda"):
+def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q3_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q3_k(data, block_size, device)
+    return KTransformersOps.dequantize_q3_k(data.data, data.size, block_size, device, target_dtype)
 
 def dequantize_q4_k(data):
     # C implementation
@@ -515,13 +533,12 @@ def dequantize_q4_k(data):
     # Dequantize final weights using scales and offsets
     return factors * qs2 - offsets
 
-def dequantize_q4_k_gpu(data, device:str ="cuda"):
+def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q4_k(data, 144, device)
+    return KTransformersOps.dequantize_q4_k(data.data, data.size, 144, device, target_dtype)
 
 def dequantize_q5_k(data):
     # C implementation
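The 144 passed here is the Q4_K block size in bytes. As an orientation note (this breakdown is stated from the ggml block_q4_K layout, not from this diff), a block packs 256 weights as a 2-byte fp16 super-scale, a 2-byte fp16 super-min, 12 bytes of packed 6-bit scales/mins, and 128 bytes of 4-bit quants. A one-line check:

# block_q4_K: d (2) + dmin (2) + packed scales/mins (12) + 4-bit quants (128) = 144 bytes per 256 weights
assert 2 + 2 + 12 + 128 == 144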
@@ -579,14 +596,13 @@ def dequantize_q5_k(data):
         d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
     ], axis=1)
 
-def dequantize_q5_k_gpu(data, device:str ="cuda"):
+def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q5_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q5_k(data, block_size, device)
+    return KTransformersOps.dequantize_q5_k(data.data, data.size, block_size, device, target_dtype)
 
 def dequantize_q6_k(data):
     # C implementation
@@ -637,13 +653,12 @@ def dequantize_q6_k(data):
     ], axis=1)
 
 # @torch.jit.script
-def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q6_K"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q6_k(data, block_size, device)
+    return KTransformersOps.dequantize_q6_k(data.data, data.size, block_size, device, target_dtype)
 
 kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
@@ -677,13 +692,12 @@ def dequantize_iq4_xs(data):
     return y.flatten()
 
-def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["IQ4_XS"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_iq4_xs(data, block_size, device)
+    return KTransformersOps.dequantize_iq4_xs(data.data, data.size, block_size, device, target_dtype)
 
 def dequantize_q4_0(data):
     # C implementation
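All of the GPU helpers above now share one (data, device, target_dtype) signature, which is what lets load_gguf_tensor dispatch through GGML_DEQUANTIZE_GPU[ggml_name](chunk, device, target_dtype) without per-type special cases. A small self-check sketch that only inspects the functions already defined above this point:

import inspect

for fn in (dequantize_q2_k_gpu, dequantize_q3_k_gpu, dequantize_q4_k_gpu,
           dequantize_q5_k_gpu, dequantize_q6_k_gpu, dequantize_iq4_xs_gpu):
    params = list(inspect.signature(fn).parameters)
    # every helper takes the raw data first, then device and target_dtype
    assert params[1:] == ["device", "target_dtype"], fn.__name__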
@@ -700,7 +714,7 @@ def dequantize_q4_0(data):
         scales * ((qs >> 4).astype(np.int8) - 8),
     ], axis=1)
 
-def dequantize_q4_0_gpu(data):
+def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()
 
 def dequantize_q5_0(data):
@@ -724,7 +738,7 @@ def dequantize_q5_0(data):
         scales * x1,
     ], axis=1)
 
-def dequantize_q5_0_gpu(data):
+def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()
 
 def dequantize_q8_0(data):
@@ -736,20 +750,19 @@ def dequantize_q8_0(data):
     qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
     return scales * qs
 
-def dequantize_q8_0_gpu(data, device:str = "cuda"):
+def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     # C struct definition
     # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
     num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
     device = torch.device(device)
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q8_0(data, 34, device)
+    return KTransformersOps.dequantize_q8_0(data.data, data.size, 34, device, target_dtype)
 
 def dequantize_f32(data):
     return np.frombuffer(data, dtype=np.float32)
 
-def dequantize_f32_gpu(data, device):
+def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float32)
     res = torch.from_numpy(data)
     res_gpu = torch.empty_like(res, device=device)
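The 34 passed to KTransformersOps.dequantize_q8_0 is the Q8_0 block size in bytes, consistent with the CPU reference just above it (reshape(num_blocks, 2 + 32)): each block is a 2-byte fp16 scale followed by 32 int8 quants. A worked size check with a hypothetical tensor:

# Q8_0: 2-byte scale + 32 int8 quants = 34 bytes per 32 weights
assert 2 + 32 == 34
# e.g. a hypothetical 4096 x 4096 Q8_0 tensor: (4096 * 4096 // 32) blocks * 34 bytes each
assert (4096 * 4096 // 32) * 34 == 17_825_792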
@@ -759,7 +772,7 @@ def dequantize_f32_gpu(data, device):
 def dequantize_f16(data):
     return np.frombuffer(data, dtype=np.float16)
 
-def dequantize_f16_gpu(data, device):
+def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float16)
     res = torch.from_numpy(data)
     res_gpu = torch.empty_like(res, device=device)
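The rest of dequantize_f16_gpu falls outside this hunk, so the copy back into res_gpu and any cast to target_dtype are not shown. A minimal sketch of the host-to-device pattern the visible lines set up; the copy_ call and the final cast are assumptions, not code from this commit:

import numpy as np
import torch

def dequantize_f16_gpu_sketch(data, device, target_dtype=torch.get_default_dtype()):
    # .copy() sidesteps the read-only frombuffer warning mentioned in the TODOs above
    host = torch.from_numpy(np.frombuffer(data, dtype=np.float16).copy())
    res_gpu = torch.empty_like(host, device=device)
    res_gpu.copy_(host)                # assumed: host -> device transfer
    return res_gpu.to(target_dtype)    # assumed: honor the requested dtype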