Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 22:05:30 +00:00)

Commit ca7366d2db: Merge remote-tracking branch 'upstream/develop-0.2.2' into support-fp8
41 changed files with 1223 additions and 314 deletions

The excerpt below is the diff of the GGUF loader / dequantization utilities from this merge; lines removed by the merge are prefixed with "-", added lines with "+".
@@ -26,6 +26,7 @@ from enum import IntEnum
 import torch
 import KTransformersOps
 from .custom_loader import SafeTensorLoader
+import ctypes
 
 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -305,7 +306,7 @@ class GGUFLoader:
         data = torch.from_numpy(data)
         return data, ggml_type
 
-    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
+    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading expert {expert_id} of {name} with CPU")
@@ -324,19 +325,21 @@ class GGUFLoader:
         data = data[offset: offset + block_size * blocks_per_experts]
 
         if "cuda" in device.lower():
-            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values)
+            values = torch.from_numpy(values.copy())
 
         values = values.view(shape[-2::-1])
 
         return values
 
-    def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
+    def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading {name} with CPU")
+        if target_dtype == None:
+            target_dtype = torch.get_default_dtype()
 
         shape = t["shape"]
         ggml_type = t["ggml_type"]
@@ -348,16 +351,38 @@ class GGUFLoader:
 
         data = self.get_mmap_tensor(name)
 
-        if "cuda" in device.lower():
-            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
-            #values = GGML_DEQUANTIZE[ggml_name](data)
-            #print("load_gguf_tensor")
-            #values = torch.from_numpy(values).to(device = device)
+        block_size = GGML_BLOCK_SIZES[ggml_name]
+        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
+        num_elements = int(np.prod(shape))
+        num_blocks = num_elements // elements_per_block
+
+        blocks_per_iter = 16384
+        if num_blocks > blocks_per_iter: # dequant large tensor
+            values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
+            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
+                blocks_begin = i * blocks_per_iter
+                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
+                if "cuda" in device.lower():
+                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
+                else:
+                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
+                    cur_values = torch.from_numpy(cur_values.copy())
+
+                cur_values = cur_values.view(-1, elements_per_block)
+                if ggml_name == "BF16":
+                    cur_values = cur_values.view(torch.bfloat16)
+                values[blocks_begin : blocks_end] = cur_values
         else:
-            values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values)
+            if "cuda" in device.lower():
+                values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+            else:
+                values = GGML_DEQUANTIZE[ggml_name](data)
+                values = torch.from_numpy(values)
 
-        if ggml_name == "BF16":
-            values = values.view(torch.bfloat16)
+            if ggml_name == "BF16":
+                values = values.view(torch.bfloat16)
 
         values = values.view(shape[::-1])
         if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
             n_head = self.gguf_file_meta['llama.attention.head_count']
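Context for the loop added above: rather than dequantizing a tensor's whole raw buffer in one call, the new path walks it in chunks of blocks_per_iter quant blocks and writes each chunk into a preallocated output tensor, so peak memory is bounded by one chunk of dequantized values regardless of tensor size. Below is a minimal standalone sketch of the same chunking pattern (not part of the commit); the 4-byte / 2-element "quant" format and the dequant_blocks stub are made up for illustration, standing in for the real GGML block formats and kernels.

import numpy as np
import torch

BLOCK_SIZE = 4          # bytes per quant block (stub format, not a real GGML layout)
ELEMENTS_PER_BLOCK = 2  # dequantized elements per block

def dequant_blocks(raw: bytes) -> np.ndarray:
    # Stub dequantizer: reinterpret each 4-byte block as two float16 values.
    return np.frombuffer(raw, dtype=np.float16).astype(np.float32)

def dequant_chunked(raw: bytes, blocks_per_iter: int = 8, target_dtype=torch.float32) -> torch.Tensor:
    num_blocks = len(raw) // BLOCK_SIZE
    out = torch.empty((num_blocks, ELEMENTS_PER_BLOCK), dtype=target_dtype)
    for i in range((num_blocks + blocks_per_iter - 1) // blocks_per_iter):
        begin = i * blocks_per_iter
        end = min(begin + blocks_per_iter, num_blocks)
        chunk = dequant_blocks(raw[begin * BLOCK_SIZE : end * BLOCK_SIZE])
        out[begin:end] = torch.from_numpy(chunk.copy()).view(-1, ELEMENTS_PER_BLOCK)
    return out.view(-1)

raw = np.arange(40, dtype=np.float16).tobytes()   # 20 stub blocks
print(dequant_chunked(raw).shape)                 # torch.Size([40])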
@@ -456,14 +481,15 @@ def dequantize_q2_k(data):
 
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
 
-def dequantize_q2_k_gpu(data, device:str ="cuda"):
+def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q2_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q2_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q2_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q3_k(data):
     # C implementation
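The repeated c_pointer change in the *_gpu helpers is what the old TODO comments ask for: instead of wrapping the read-only buffer returned by np.frombuffer in a torch tensor (which triggers the "NumPy array is not writable" warning), the raw host address of the buffer is handed to KTransformersOps together with the element count, block size, elements per block, device and target dtype. A small self-contained check of the pointer expression, using only numpy and ctypes:

import ctypes
import numpy as np

buf = bytes(range(16))                      # stand-in for a read-only mmapped GGUF region
arr = np.frombuffer(buf, dtype=np.int8)     # zero-copy view; arr.flags.writeable is False

# Same expression as in the diff: cast the buffer address to an int8 pointer,
# then take the address of the pointed-to object, i.e. the first byte of the buffer.
c_pointer = ctypes.addressof(ctypes.cast(arr.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)

assert c_pointer == arr.ctypes.data         # both name the same host address
print(hex(c_pointer), arr.size)             # address and element count passed to the extension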
@@ -507,14 +533,15 @@ def dequantize_q3_k(data):
         (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
     ], axis=1)
 
-def dequantize_q3_k_gpu(data, device:str ="cuda"):
+def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q3_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q3_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q3_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q4_k(data):
     # C implementation
@@ -538,13 +565,15 @@ def dequantize_q4_k(data):
     # Dequantize final weights using scales and offsets
     return factors * qs2 - offsets
 
-def dequantize_q4_k_gpu(data, device:str ="cuda"):
+def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
+    block_size = GGML_BLOCK_SIZES["Q4_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q4_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q4_k(data, 144, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q5_k(data):
     # C implementation
@@ -602,14 +631,15 @@ def dequantize_q5_k(data):
         d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
     ], axis=1)
 
-def dequantize_q5_k_gpu(data, device:str ="cuda"):
+def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q5_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q5_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q5_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q6_k(data):
     # C implementation
@@ -660,13 +690,14 @@ def dequantize_q6_k(data):
     ], axis=1)
 
 # @torch.jit.script
-def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q6_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q6_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
 
@@ -700,13 +731,14 @@ def dequantize_iq4_xs(data):
 
     return y.flatten()
 
-def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["IQ4_XS"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
    data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_iq4_xs(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_iq4_xs(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q4_0(data):
     # C implementation
@@ -723,7 +755,7 @@ def dequantize_q4_0(data):
         scales * ((qs >> 4).astype(np.int8) - 8),
     ], axis=1)
 
-def dequantize_q4_0_gpu(data):
+def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()
 
 def dequantize_q5_0(data):
@@ -747,7 +779,7 @@ def dequantize_q5_0(data):
         scales * x1,
     ], axis=1)
 
-def dequantize_q5_0_gpu(data):
+def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()
 
 def dequantize_q8_0(data):
@@ -759,32 +791,41 @@ def dequantize_q8_0(data):
     qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
     return scales * qs
 
-def dequantize_q8_0_gpu(data, device:str = "cuda"):
+def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     # C struct definition
     # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
     num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
 
+    block_size = GGML_BLOCK_SIZES["Q8_0"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q8_0"]
     device = torch.device(device)
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q8_0(data, 34, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q8_0(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 
 def dequantize_f32(data):
     return np.frombuffer(data, dtype=np.float32)
 
-def dequantize_f32_gpu(data, device):
+def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float32)
-    res = torch.from_numpy(data)
-    res_gpu = torch.empty_like(res, device=device)
+    res = torch.from_numpy(data.copy())
+    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
     res_gpu.copy_(res)
     return res_gpu
 
 def dequantize_f16(data):
     return np.frombuffer(data, dtype=np.float16)
 
-def dequantize_f16_gpu(data, device):
+def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float16)
-    res = torch.from_numpy(data)
+    res = torch.from_numpy(data.copy())
+    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
     res_gpu.copy_(res)
     return res_gpu
 
+def dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()):
+    data = np.frombuffer(data, dtype=np.float16)
+    res = torch.from_numpy(data.copy())
+    res_gpu = torch.empty_like(res, device=device)
+    res_gpu.copy_(res)
+    return res_gpu
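The F32/F16 helpers above now copy the NumPy view before wrapping it (avoiding the non-writable warning) and allocate the destination tensor directly in target_dtype, so the copy_ into it doubles as the dtype cast; BF16 gets its own helper that keeps the raw 16-bit payload untouched for the caller to reinterpret with view(torch.bfloat16). A minimal sketch of the F16 path that runs on CPU (the function name, device and dtype here are illustrative, not the module's API):

import numpy as np
import torch

def dequantize_f16_sketch(data: bytes, device="cpu", target_dtype=torch.bfloat16):
    arr = np.frombuffer(data, dtype=np.float16)            # zero-copy, read-only view
    res = torch.from_numpy(arr.copy())                     # copy first so the tensor owns writable memory
    res_out = torch.empty_like(res, device=device, dtype=target_dtype)
    res_out.copy_(res)                                     # copy and cast to target_dtype in one step
    return res_out

raw = np.array([0.5, 1.0, -2.0], dtype=np.float16).tobytes()
print(dequantize_f16_sketch(raw))                          # tensor([ 0.5000,  1.0000, -2.0000], dtype=torch.bfloat16)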
@@ -807,7 +848,7 @@ GGML_DEQUANTIZE = {
 GGML_DEQUANTIZE_GPU = {
     "F32": dequantize_f32_gpu,
     "F16": dequantize_f16_gpu,
-    "BF16": dequantize_f16_gpu,
+    "BF16": dequantize_bf16_gpu,
     "Q4_0": dequantize_q4_0_gpu,
     "Q5_0": dequantize_q5_0_gpu,
     "Q8_0": dequantize_q8_0_gpu,
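These lookup tables are the dispatch point used by load_expert_tensor and load_gguf_tensor: GPU entries take (data, device, target_dtype), with target_dtype defaulted, while CPU entries take only (data); that is also why the still-unimplemented Q4_0/Q5_0 GPU stubs gained the extra parameters, keeping every entry on the same signature. A toy sketch of that dispatch (the stub tables and the load helper are illustrative only):

import numpy as np
import torch

# Stub tables with the same calling conventions as GGML_DEQUANTIZE / GGML_DEQUANTIZE_GPU.
DEQUANT_CPU = {"F32": lambda data: np.frombuffer(data, dtype=np.float32)}
DEQUANT_GPU = {"F32": lambda data, device, target_dtype: torch.from_numpy(
    np.frombuffer(data, dtype=np.float32).copy()).to(device=device, dtype=target_dtype)}

def load(data: bytes, ggml_name: str, device: str, target_dtype=torch.float32):
    # Mirrors the branch in load_expert_tensor / load_gguf_tensor.
    if "cuda" in device.lower():
        return DEQUANT_GPU[ggml_name](data, device, target_dtype)
    values = DEQUANT_CPU[ggml_name](data)
    return torch.from_numpy(values.copy())

raw = np.array([1.0, 2.0], dtype=np.float32).tobytes()
print(load(raw, "F32", device="cpu"))   # tensor([1., 2.])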