From 7e1fe256c8c7235e620671e76a72c84a2ddb27e3 Mon Sep 17 00:00:00 2001
From: Atream
Date: Fri, 21 Feb 2025 05:06:57 +0000
Subject: [PATCH] optimize GPU

---
 .../ktransformers_ext/cuda/binding.cpp        |  14 +-
 .../cuda/custom_gguf/binding.cpp              |  22 +-
 .../cuda/custom_gguf/dequant.cu               | 681 +++++++++++++++---
 .../ktransformers_ext/cuda/custom_gguf/ops.h  |  14 +-
 ktransformers/local_chat.py                   |   5 +-
 .../optimize_rules/DeepSeek-V3-Chat.yaml      |  12 +
 .../backend/interfaces/ktransformers.py       |   2 +-
 ktransformers/util/custom_gguf.py             |  83 ++-
 8 files changed, 677 insertions(+), 156 deletions(-)

diff --git a/ktransformers/ktransformers_ext/cuda/binding.cpp b/ktransformers/ktransformers_ext/cuda/binding.cpp
index 65c8bc4..1f89b31 100644
--- a/ktransformers/ktransformers_ext/cuda/binding.cpp
+++ b/ktransformers/ktransformers_ext/cuda/binding.cpp
@@ -20,19 +20,19 @@ PYBIND11_MODULE(KTransformersOps, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
           py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
           py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
index 99069d8..2011247 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
@@ -17,19 +17,19 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("test", &test, "Function to test.");
 }
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
index d5b4a2c..d5184ce 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
@@ -10,19 +10,53 @@
  * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*/ #include +#include +#include #include #include #include #include #include -__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) { +__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; - for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); - const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 80))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16); @@ -70,7 +104,75 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size } } -__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { +__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); + + const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16); + + int is = 0; + float dl, ml; + + for (int n = 0; n < 256; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++)); + uint8_t sc = *scales; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l] >> shift) & 3)) - ml); + + scales = (uint8_t*)(data + block_id * blk_size + (is++)); + sc = *scales; + + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml); + + shift += 2; + } + q += 32; + } + } +} + +__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); + + const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16); + + int is = 0; + float dl, ml; + + for (int n = 0; n < 256; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++)); + uint8_t sc = *scales; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l] >> shift) & 3)) - ml); + + scales = (uint8_t*)(data + block_id * blk_size + (is++)); + sc = *scales; + + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml); + + shift += 2; + } + q += 32; + } + } +} + +__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; const uint32_t kmask1 = 0x03030303; @@ -80,7 +182,7 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size uint32_t aux[4]; const int8_t * 
scales = (const int8_t*)aux; - const float d_all = __half2float(*(reinterpret_cast(data + block_id * blk_size + 108))); + const float d_all = __half2float(*(reinterpret_cast(data + block_id * blk_size + 108))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32); const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0); @@ -126,16 +228,128 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size } } +__global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { + + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 108))); + + const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32); + const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0); + uint8_t m = 1; + + + uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96); + + for (int i = 0; i < 3; i++) { + aux[i] = 0; + for (int j = 0; j < 4; j++) { + aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8); + } + } + + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < 256; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *output_blk++ = __float2half(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4))); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4))); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + } +} + +__global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { + + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 108))); + + const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32); + const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0); + uint8_t m = 1; + + + uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96); + + for (int i = 0; i < 3; i++) { + aux[i] = 0; + for (int j = 0; j < 4; j++) { + aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8); + } + } + + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < 256; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4))); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 
0 : 4))); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + } +} + + +__global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * 144 + 0))); - const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); + const float d = __half2float(*(reinterpret_cast(data + block_id * 144 + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); int is = 0; uint8_t sc, m; for (int j = 0; j < blk_size; j += 64) { @@ -151,13 +365,61 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size } } -__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { +__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * 144 + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); + int is = 0; + uint8_t sc, m; + for (int j = 0; j < blk_size; j += 64) { + uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * (q[l] & 0xF) - m1); + for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * (q[l] >> 4) - m2); + q += 32; is += 2; + } + } +} + +__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * 144 + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); + int is = 0; + uint8_t sc, m; + for (int j = 0; j < blk_size; j += 64) { + uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * (q[l] & 0xF) - m1); + for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * (q[l] >> 4) - m2); + q += 32; is += 2; + } + } +} + +__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ float* __restrict__ output_blk = (float*)(output + block_id * 256); - const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); - const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); @@ -180,11 +442,69 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, 
float* output, int blk_size } } -__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { +__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ + __half* __restrict__ output_blk = (__half*)(output + block_id * 256); + + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); + + const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); + const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4); + + for (int j = 0; j < 256; j += 64) { + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1); + for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2); + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } + } +} + +__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ + nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); + + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); + + const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); + const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4); + + for (int j = 0; j < 256; j += 64) { + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1); + for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 
16 : 0)) - m2); + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } + } +} + +__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 208))); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128); @@ -212,14 +532,78 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size } } +__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); + + const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); + const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128); + const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); + + + //if (blk_size == 256){ + for (int n = 0; n < blk_size; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + output_blk[l + 0] = __float2half(d * sc[is + 0] * q1); + output_blk[l + 32] = __float2half(d * sc[is + 2] * q2); + output_blk[l + 64] = __float2half(d * sc[is + 4] * q3); + output_blk[l + 96] = __float2half(d * sc[is + 6] * q4); + } + output_blk += 128; + ql += 64; + qh += 32; + sc += 8; + } + } +} + +__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); + + const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); + const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128); + const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); + + + //if (blk_size == 256){ + for (int n = 0; n < blk_size; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + output_blk[l + 0] = __float2bfloat16(d * sc[is + 0] * q1); + output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2); + output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3); + output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4); + } + output_blk += 128; + ql += 64; + qh += 32; + sc += 8; + } + } +} + static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { +__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* 
output, const int blk_size, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size))); - const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size))); + const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4); @@ -236,152 +620,267 @@ __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_si } } -torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size))); + const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); + const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); + const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4); + + for (int ib = 0; ib < 8; ++ib) { + const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4); + const float dl = d * (ls - 32); + for (int j = 0; j < 16; ++j) { + output_blk[j + 0] = __float2half(dl * kvalues_iq4nl[qs[j] & 0xf]); + output_blk[j + 16] = __float2half(dl * kvalues_iq4nl[qs[j] >> 4]); + } + output_blk += 32; + qs += 16; + } + } +} + +__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size))); + const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); + const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); + const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4); + + for (int ib = 0; ib < 8; ++ib) { + const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4); + const float dl = d * (ls - 32); + for (int j = 0; j < 16; ++j) { + output_blk[j + 0] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] & 0xf]); + output_blk[j + 16] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] >> 4]); + } + output_blk += 32; + qs += 16; + } + } +} + +torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); - // create gpu - auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto scales_gpu = torch::empty({{num_blocks, 1}}, options_scales); - auto qs_gpu = torch::empty({num_blocks, 32}, options_qs); - // read on cpu - options_scales = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCPU); - options_qs = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU); + auto options = 
torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); + auto data_gpu = torch::empty({ num_bytes }, options); - // // reinterpret - auto scales = torch::from_blob(data.data_ptr(), {num_blocks, 1 + 16}, options_scales).slice(1, 0, 1); - auto qs = torch::from_blob(data.data_ptr(), {num_blocks, 2 + 32}, options_qs).slice(1, 2); - - auto scales_f32 = scales.to(torch::kFloat32); - scales_gpu.copy_(scales_f32, false); - qs_gpu.copy_(qs, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros_like(qs, torch::dtype(torch::kFloat32).device(device)); + auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device)); - // Launch kernel - dequantize_q8_0_kernel<<< 512, 256 >>>( - output.data_ptr(), scales_gpu.data_ptr(), qs_gpu.data_ptr(), num_blocks, 32); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + break; + case torch::kBFloat16: + dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + case torch::kFloat32: + dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device) { +torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { // data.numel%blk_size should be 0, else raise err - int num_blocks = data.numel() / blk_size; + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); - // dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), 256, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + break; + case torch::kBFloat16: + dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + case torch::kFloat32: + dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const 
torch::Device device, const torch::ScalarType target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + break; + case torch::kBFloat16: + dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + case torch::kFloat32: + dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) { +torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { // data.numel%blk_size should be 0, else raise err - int num_blocks = data.numel() / blk_size; + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q4_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), 256, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + break; + case torch::kBFloat16: + dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + case torch::kFloat32: + dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = 
torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + break; + case torch::kBFloat16: + dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + case torch::kFloat32: + dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + break; + case torch::kBFloat16: + dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + case torch::kFloat32: + dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - 
data_gpu.copy_(data, false);
+    cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice);
+    //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
-
-    // Launch kernel
-    dequantize_iq4_xs_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks);
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    switch (target_dtype) {
+        case torch::kFloat16:
+            dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks);
+            break;
+        case torch::kBFloat16:
+            dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks);
+            break;
+        case torch::kFloat32:
+            dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks);
+            break;
+        default:
+            printf("target type not support\n");
+            exit(0);
+    }
     cudaDeviceSynchronize();
     return output;
 }
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
index 666d455..b18c799 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
@@ -13,10 +13,10 @@
 #include
 #include
 
-torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index fb59a17..d5e74de 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -168,10 +168,7 @@ def local_chat(
     if mode == 'long_context':
         assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
         "please change max_seq_len in ~/.ktransformers/config.yaml"
-        torch.set_default_dtype(
-            torch.bfloat16
-        ) # TODO: Remove this, replace dtype using config
-
+
     if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
index 7a44c5d..6fb6586 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
@@ -5,6 +5,18 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+
+- match:
+    name: "^lm_head$" # regular expression
+    class: torch.nn.Linear # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index edca541..49a3f16 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext):
 
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
-        torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index eaa1a7d..7ad13c7 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -285,7 +285,7 @@ class GGUFLoader:
         itemsize = int(np.empty([], dtype = item_type).itemsize)
         return mmap_data[offset : offset + itemsize * item_count]
 
-    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
+    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading expert {expert_id} of {name} with CPU")
@@ -304,7 +304,7 @@ class GGUFLoader:
         data = data[offset: offset + block_size * blocks_per_experts]
 
         if "cuda" in device.lower():
-            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
             values = torch.from_numpy(values)
@@ -313,7 +313,7 @@ class GGUFLoader:
 
         return values
 
-    def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
+    def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = torch.get_default_dtype())->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading {name} with CPU")
@@ -328,16 +328,36 @@
class GGUFLoader: data = self.get_mmap_tensor(name) - if "cuda" in device.lower(): - values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) - #values = GGML_DEQUANTIZE[ggml_name](data) - #print("load_gguf_tensor") - #values = torch.from_numpy(values).to(device = device) + block_size = GGML_BLOCK_SIZES[ggml_name] + elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name] + num_elements = int(np.prod(shape)) + num_blocks = num_elements // elements_per_block + + blocks_per_iter = 16384 + if num_blocks > blocks_per_iter: # dequant large tensor + values = torch.empty((num_blocks, elements_per_block), dtype=torch.float, device=device) + for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter): + blocks_begin = i * blocks_per_iter + blocks_end = min(blocks_begin + blocks_per_iter, num_blocks) + if "cuda" in device.lower(): + cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype) + else: + cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) + cur_values = torch.from_numpy(cur_values) + + cur_values = cur_values.view(-1, elements_per_block) + values[blocks_begin : blocks_end] = cur_values else: - values = GGML_DEQUANTIZE[ggml_name](data) - values = torch.from_numpy(values) + if "cuda" in device.lower(): + values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) + else: + values = GGML_DEQUANTIZE[ggml_name](data) + values = torch.from_numpy(values) + if ggml_name == "BF16": values = values.view(torch.bfloat16) + + values = values.view(shape[::-1]) if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: n_head = self.gguf_file_meta['llama.attention.head_count'] @@ -433,14 +453,13 @@ def dequantize_q2_k(data): return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) -def dequantize_q2_k_gpu(data, device:str ="cuda"): +def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q2_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q2_k(data, block_size, device) + return KTransformersOps.dequantize_q2_k(data.data, data.size, block_size, device, target_dtype) def dequantize_q3_k(data): # C implementation @@ -484,14 +503,13 @@ def dequantize_q3_k(data): (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]) ], axis=1) -def dequantize_q3_k_gpu(data, device:str ="cuda"): +def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q3_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. 
- data = torch.from_numpy(data) - return KTransformersOps.dequantize_q3_k(data, block_size, device) + return KTransformersOps.dequantize_q3_k(data.data, data.size, block_size, device, target_dtype) def dequantize_q4_k(data): # C implementation @@ -515,13 +533,12 @@ def dequantize_q4_k(data): # Dequantize final weights using scales and offsets return factors * qs2 - offsets -def dequantize_q4_k_gpu(data, device:str ="cuda"): +def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q4_k(data, 144, device) + return KTransformersOps.dequantize_q4_k(data.data, data.size, 144, device, target_dtype) def dequantize_q5_k(data): # C implementation @@ -579,14 +596,13 @@ def dequantize_q5_k(data): d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, ], axis=1) -def dequantize_q5_k_gpu(data, device:str ="cuda"): +def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q5_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q5_k(data, block_size, device) + return KTransformersOps.dequantize_q5_k(data.data, data.size, block_size, device, target_dtype) def dequantize_q6_k(data): # C implementation @@ -637,13 +653,12 @@ def dequantize_q6_k(data): ], axis=1) # @torch.jit.script -def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"): +def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q6_K"] device = torch.device(device) num_blocks = len(data) // block_size data = np.frombuffer(data, dtype=data.dtype) - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q6_k(data, block_size, device) + return KTransformersOps.dequantize_q6_k(data.data, data.size, block_size, device, target_dtype) kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8) @@ -677,13 +692,12 @@ def dequantize_iq4_xs(data): return y.flatten() -def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"): +def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["IQ4_XS"] device = torch.device(device) num_blocks = len(data) // block_size data = np.frombuffer(data, dtype=data.dtype) - data = torch.from_numpy(data) - return KTransformersOps.dequantize_iq4_xs(data, block_size, device) + return KTransformersOps.dequantize_iq4_xs(data.data, data.size, block_size, device, target_dtype) def dequantize_q4_0(data): # C implementation @@ -700,7 +714,7 @@ def dequantize_q4_0(data): scales * ((qs >> 4).astype(np.int8) - 8), ], axis=1) -def dequantize_q4_0_gpu(data): +def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): raise NotImplementedError() def dequantize_q5_0(data): @@ -724,7 +738,7 @@ def dequantize_q5_0(data): scales * x1, ], axis=1) -def 
dequantize_q5_0_gpu(data): +def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): raise NotImplementedError() def dequantize_q8_0(data): @@ -736,20 +750,19 @@ def dequantize_q8_0(data): qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:] return scales * qs -def dequantize_q8_0_gpu(data, device:str = "cuda"): +def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43 num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"] device = torch.device(device) data = np.frombuffer(data, dtype=data.dtype) - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q8_0(data, 34, device) + return KTransformersOps.dequantize_q8_0(data.data, data.size, 34, device, target_dtype) def dequantize_f32(data): return np.frombuffer(data, dtype=np.float32) -def dequantize_f32_gpu(data, device): +def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=np.float32) res = torch.from_numpy(data) res_gpu = torch.empty_like(res, device=device) @@ -759,7 +772,7 @@ def dequantize_f32_gpu(data, device): def dequantize_f16(data): return np.frombuffer(data, dtype=np.float16) -def dequantize_f16_gpu(data, device): +def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=np.float16) res = torch.from_numpy(data) res_gpu = torch.empty_like(res, device=device)
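
A minimal usage sketch of the new entry points (not part of the patch): after this commit the GPU dequantize functions take a raw byte buffer, its length in bytes, the GGML block size, the target device, and a target_dtype, and return a (num_blocks, elements_per_block) CUDA tensor already materialized in that dtype. The code below mirrors the updated dequantize_q6_k_gpu wrapper in ktransformers/util/custom_gguf.py; the names dequantize_q6_k_gpu_sketch, raw_bytes, and Q6_K_BLOCK_BYTES are illustrative only, the 210-byte block size comes from GGML's Q6_K layout, and it assumes the rebuilt KTransformersOps extension from this patch is importable.

    # Sketch only, assuming the rebuilt KTransformersOps extension from this patch.
    import numpy as np
    import torch
    import KTransformersOps

    Q6_K_BLOCK_BYTES = 210  # bytes per Q6_K block (GGML_BLOCK_SIZES["Q6_K"])

    def dequantize_q6_k_gpu_sketch(raw_bytes, device="cuda", target_dtype=torch.bfloat16):
        # Zero-copy int8 view over the mmap'd GGUF bytes; no intermediate
        # torch.from_numpy() host tensor is needed any more.
        data = np.frombuffer(raw_bytes, dtype=np.int8)
        device = torch.device(device)
        # New signature: raw buffer, byte count, block size, device, target dtype.
        # The extension dispatches to the fp32/fp16/bf16 kernel and returns a
        # (num_blocks, 256) CUDA tensor already in target_dtype.
        return KTransformersOps.dequantize_q6_k(data.data, data.size, Q6_K_BLOCK_BYTES,
                                                device, target_dtype)

Passing the byte view and count directly avoids the per-tensor torch.from_numpy copy of the old path, and letting the kernels emit fp16/bf16 output removes the fp32 intermediate, which is also why the torch.set_default_dtype(torch.bfloat16) workarounds in local_chat.py and the server interface could be replaced by config.torch_dtype.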