Support Windows; support Q4_0 and Q5_0 dequant on CPU. Add copyright from pygguf (it was added before, but disappeared after a merge). Add some TODOs in the code.
This commit is contained in:
parent 442e13bc97
commit 0a2fd52cea

32 changed files with 248 additions and 108 deletions
@@ -7,6 +7,9 @@ Date : 2024-07-26 08:48:54
Version : 1.0.0
LastEditors : Azure
LastEditTime : 2024-07-26 09:28:25
Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
Copyright (c) 2023-2024 The ggml authors
Copyright (c) 2024 Thomas Germer
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
# copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf

@@ -95,7 +98,8 @@ def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantization

GGML_TYPES = {
    "F32": 0,
    "F16": 1,
    "Q4_0": 2,
    "Q5_0": 6,
    "Q8_0": 8,
    "Q2_K": 10,
    "Q3_K": 11,

@@ -108,7 +112,8 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}

GGML_BLOCK_SIZES = {
    "F32": 4,
    "F16": 2,
    "Q4_0": 2 + 16,
    "Q5_0": 2 + 4 + 16,
    "Q8_0": 2 + 32,
    "Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
    "Q3_K": 256 // 8 + 256 // 4 + 12 + 2,

@@ -119,7 +124,8 @@ GGML_BLOCK_SIZES = {

GGML_ELEMENTS_PER_BLOCK = {
    "F32": 1,
    "F16": 1,
    "Q4_0": 32,
    "Q5_0": 32,
    "Q8_0": 32,
    "Q2_K": 256,
    "Q3_K": 256,

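Taken together, GGML_BLOCK_SIZES and GGML_ELEMENTS_PER_BLOCK fix the storage cost of every quantized type: weights are packed into fixed-size blocks, so a tensor occupies (elements / elements-per-block) * block-size bytes. A minimal sketch of that arithmetic (the helper name tensor_nbytes is ours, not part of the diff):

import numpy as np

def tensor_nbytes(shape, ggml_name: str) -> int:
    # Each block stores GGML_ELEMENTS_PER_BLOCK[ggml_name] weights in
    # GGML_BLOCK_SIZES[ggml_name] bytes.
    n = int(np.prod(shape))
    assert n % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0
    return n // GGML_ELEMENTS_PER_BLOCK[ggml_name] * GGML_BLOCK_SIZES[ggml_name]

# e.g. a 4096 x 4096 Q4_0 tensor: 16777216 // 32 * 18 = 9437184 bytes, i.e. 4.5 bits per weight.
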
@@ -128,14 +134,6 @@ GGML_ELEMENTS_PER_BLOCK = {
    "Q6_K": 256,
}

# DATA_TYPES = {
#     "uint32": 4,
#     "int32": 5,
#     "float32": 6,
#     "string": 8,
#     "array": 9,
#     "uint64": 10,
# }
DATA_TYPES = {
    "uint8": 0,
    "int8": 1,

@@ -272,7 +270,7 @@ class GGUFLoader:

    def load_gguf_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
        t = self.tensor_info[name]

        shape = t["shape"]
        ggml_type = t["ggml_type"]

@@ -282,10 +280,12 @@
        ggml_name = GGML_NAMES[ggml_type]

        data = self.get_mmap_tensor(name)

        if "cuda" in device.lower():
            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
            # values = GGML_DEQUANTIZE[ggml_name](data)
            # print("load_gguf_tensor")
            # values = torch.from_numpy(values).to(device=device)
        else:
            values = GGML_DEQUANTIZE[ggml_name](data)
            values = torch.from_numpy(values)

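For orientation, a hypothetical call sequence (the constructor argument and tensor name are illustrative; this diff does not show GGUFLoader's constructor):

loader = GGUFLoader("/path/to/model.gguf")                               # assumed .gguf path
w_cpu = loader.load_gguf_tensor("token_embd.weight", device="cpu")       # numpy dequant -> torch.from_numpy
w_gpu = loader.load_gguf_tensor("token_embd.weight", device="cuda:0")    # dequantized by the GPU kernel
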
@@ -375,7 +375,7 @@ def dequantize_q2_k(data):
    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)

def dequantize_q2_k_gpu(data):
    pass
    raise NotImplementedError()

def dequantize_q3_k(data):
    # C implementation

@@ -420,7 +420,7 @@ def dequantize_q3_k(data):
    ], axis=1)

def dequantize_q3_k_gpu(data):
    pass
    raise NotImplementedError()

def dequantize_q4_k(data):
    # C implementation

@@ -429,20 +429,16 @@ def dequantize_q4_k(data):
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
    block_size = GGML_BLOCK_SIZES["Q4_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    # Casting to float32 because float16 is very slow on CPU
    scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
    scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
    qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
    qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)

    # Dequantize scales and offsets (6 bits and 4 + 2 bits)
    factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
    offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)

    # Interleave low and high quantized bits
    qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
    # Dequantize final weights using scales and offsets

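The np.stack([qs2 & 0xf, qs2 >> 4], axis=2) line above splits every byte into its low and high nibble and lays the two halves out as alternating rows. A tiny standalone check of that reshuffle (toy values, not from the diff):

import numpy as np

qs2 = np.array([[[0x21, 0x43]]], dtype=np.uint8)                 # (1 block, 1 row, 2 bytes)
rows = np.stack([qs2 & 0xF, qs2 >> 4], axis=2).reshape(1, 2, 2)
# rows[0] == [[1, 3],   <- low nibbles of 0x21, 0x43
#             [2, 4]]   <- high nibbles of 0x21, 0x43
assert (rows[0] == np.array([[1, 3], [2, 4]])).all()
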
@@ -513,7 +509,7 @@ def dequantize_q5_k(data):
    ], axis=1)

def dequantize_q5_k_gpu(data):
    pass
    raise NotImplementedError()


def dequantize_q6_k(data):

@@ -573,6 +569,48 @@ def dequantize_q6_k_gpu(data: np.ndarray, device: str = "cuda"):
    data = torch.from_numpy(data)
    return KTransformersOps.dequantize_q6_k(data, 210, device)

def dequantize_q4_0(data):
    # C implementation
    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515
    # C struct definition
    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L141
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]

    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)
    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]

    return np.concatenate([
        scales * ((qs & 0xf).astype(np.int8) - 8),
        scales * ((qs >> 4).astype(np.int8) - 8),
    ], axis=1)

def dequantize_q4_0_gpu(data):
    raise NotImplementedError()

def dequantize_q5_0(data):
    # C implementation
    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1556
    # C struct definition
    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L161
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]

    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
    qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]

    bits = np.unpackbits(qh, axis=-1, bitorder="little")

    x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
    x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16

    return np.concatenate([
        scales * x0,
        scales * x1,
    ], axis=1)

def dequantize_q5_0_gpu(data):
    raise NotImplementedError()

def dequantize_q8_0(data):
    # C struct definition
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43

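Both new CPU paths can be sanity-checked against hand-built blocks; the test values below are ours, chosen so the expected outputs follow directly from the formulas above:

import numpy as np

# Q4_0: one 18-byte block = float16 scale + 16 bytes of packed nibbles.
block = np.float16(2.0).tobytes() + bytes([0xF0] * 16)   # low nibble 0, high nibble 15
out = dequantize_q4_0(block)
assert out.shape == (1, 32)
assert np.allclose(out[0, :16], 2.0 * (0 - 8))    # -16.0
assert np.allclose(out[0, 16:], 2.0 * (15 - 8))   #  14.0

# Q5_0: one 22-byte block = float16 scale + 4 bytes of high bits + 16 nibble bytes.
block = np.float16(1.0).tobytes() + (0xFFFF).to_bytes(4, "little") + bytes(16)
out = dequantize_q5_0(block)
assert np.allclose(out[0, :16], (0 | 1 << 4) - 16)   #  0.0, fifth bit set
assert np.allclose(out[0, 16:], 0 - 16)              # -16.0, fifth bit clear
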
@@ -615,6 +653,8 @@ def dequantize_f16_gpu(data, device):

GGML_DEQUANTIZE = {
    "F32": dequantize_f32,
    "F16": dequantize_f16,
    "Q4_0": dequantize_q4_0,
    "Q5_0": dequantize_q5_0,
    "Q8_0": dequantize_q8_0,
    "Q2_K": dequantize_q2_k,
    "Q3_K": dequantize_q3_k,

@@ -626,6 +666,8 @@ GGML_DEQUANTIZE = {

GGML_DEQUANTIZE_GPU = {
    "F32": dequantize_f32_gpu,
    "F16": dequantize_f16_gpu,
    "Q4_0": dequantize_q4_0_gpu,
    "Q5_0": dequantize_q5_0_gpu,
    "Q8_0": dequantize_q8_0_gpu,
    "Q2_K": dequantize_q2_k_gpu,
    "Q3_K": dequantize_q3_k_gpu,

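With Q4_0 and Q5_0 registered in both tables, type dispatch is a plain name lookup. A minimal sketch of the pattern load_gguf_tensor follows (the wrapper itself is illustrative, not part of the diff):

import torch

def dequantize(ggml_name: str, data, device: str = "cpu"):
    # GPU entries return a torch tensor already on `device`; CPU entries
    # return numpy, which is converted with torch.from_numpy.
    if "cuda" in device.lower():
        return GGML_DEQUANTIZE_GPU[ggml_name](data, device)
    return torch.from_numpy(GGML_DEQUANTIZE[ggml_name](data))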