Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-13 00:29:59 +00:00)
support safetensor load, delete architectures argument
This commit is contained in:
parent 900a7f7c3e
commit c6aa379de2
30 changed files with 1075 additions and 328 deletions
@@ -25,7 +25,6 @@ import os
from enum import IntEnum
import torch
import KTransformersOps
from .custom_loader import SafeTensorLoader
import ctypes
import math

@@ -166,238 +165,6 @@ DATA_TYPES = {
    "FP8": 13,
}

class GGUFLoader:
    tensor_info: dict
    gguf_path: str
    tensor_file_map: dict # {tensor_name: tensor_file_path}
    gguf_file_meta: dict
    safetensor_loader: SafeTensorLoader

    def __init__(self, gguf_path: str):
        # Check dir exist
        if not os.path.exists(gguf_path):
            raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
        if os.path.isfile(gguf_path):
            gguf_path = os.path.dirname(gguf_path)

        self.safetensor_loader = None

        self.tensor_info = {}
        self.gguf_path = gguf_path
        self.tensor_file_map = {}
        self.file_data_map = {}
        self.gguf_file_meta = {}
        self.tensor_device_map = {}

        # I know this is ugly, but I don't want to change the original code too much
        # TODO: merge gguf load and other loads.
        safetensor_loader = SafeTensorLoader(gguf_path)
        if safetensor_loader.tensor_file_map:
            self.safetensor_loader = safetensor_loader
            return
        # Walk through all the .gguf files in the directory
        found_gguf = False
        for root, dirs, files in os.walk(gguf_path):
            for file in files:
                if file.endswith(".gguf"):
                    found_gguf = True
                    file_name = os.path.join(root, file)
                    with open(file_name, "rb") as f:
                        self.load_gguf(f)
                        if file_name not in self.file_data_map:
                            self.file_data_map[file_name] = np.memmap(file_name, mode = 'r')
        if not found_gguf:
            raise FileNotFoundError(f"Cannot find any .gguf files in: {gguf_path}")
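The constructor prefers safetensors whenever SafeTensorLoader finds any tensor files and only falls back to scanning for .gguf files otherwise. A minimal usage sketch of that fallback (the directory path and print-outs are illustrative, not from this commit):

# Hedged sketch: point the loader at a directory holding either *.safetensors or *.gguf shards.
loader = GGUFLoader("/path/to/model_dir")   # placeholder path
if loader.safetensor_loader is not None:
    print("safetensor route:", len(loader.safetensor_loader.tensor_file_map), "tensors")
else:
    print("gguf route:", len(loader.tensor_info), "tensors mapped via np.memmap")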
    def load_gguf(self, f):
        f.seek(0)
        assert f.read(4) == b'GGUF'
        values = struct.unpack("<IQQ", f.read(4+8+8))
        version, n_tensors, n_kv = values
        if version != 3:
            warnings.warn(f"Version {version} has never been tested, might not work")

        info = {}
        for _ in range(n_kv):
            name = read_value(f, DATA_TYPES["string"])

            data_type = struct.unpack("<I", f.read(4))[0]

            info[name] = read_value(f, data_type)

        tensor_info = {}
        for _ in range(n_tensors):
            name = read_value(f, DATA_TYPES["string"])
            shape_len = read_value(f, DATA_TYPES["uint32"])
            shape = [read_value(f, DATA_TYPES["uint64"]) for _ in range(shape_len)]
            ggml_type = read_value(f, DATA_TYPES["uint32"])
            bad_offset = read_value(f, DATA_TYPES["uint64"])
            n_elems = int(math.prod(shape))
            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
            n_bytes = n_elems * type_size // block_size
            np_dims = tuple(reversed(shape))

            item_type: npt.DTypeLike
            if ggml_type == GGMLQuantizationType.F16:
                item_count = n_elems
                item_type = np.float16
            elif ggml_type == GGMLQuantizationType.F32:
                item_count = n_elems
                item_type = np.float32
            elif ggml_type == GGMLQuantizationType.F64:
                item_count = n_elems
                item_type = np.float64
            elif ggml_type == GGMLQuantizationType.I8:
                item_count = n_elems
                item_type = np.int8
            elif ggml_type == GGMLQuantizationType.I16:
                item_count = n_elems
                item_type = np.int16
            elif ggml_type == GGMLQuantizationType.I32:
                item_count = n_elems
                item_type = np.int32
            elif ggml_type == GGMLQuantizationType.I64:
                item_count = n_elems
                item_type = np.int64
            else:
                item_count = n_bytes
                item_type = np.uint8
                np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)

            tensor_info[name] = {
                "ggml_type": ggml_type,
                "shape": shape,
                "bad_offset": bad_offset,
                "item_type": item_type,
                "item_count": item_count,
                "np_dims": np_dims
            }

        start = f.tell()
        # Alignment is 32 by default.
        # https://github.com/ggerganov/ggml/blob/e1daebbf9d38d510ba456c4d50b4500a73ac2b14/docs/gguf.md?plain=1#L253
        alignment = info.get("general.alignment", 32)

        # Inconveniently, the offset defined in gguf files is relative to the
        # end of the header and is unaligned.
        # We need to compute the absolute file offset ourselves instead.
        for t in tensor_info.values():
            offset = start + t["bad_offset"]
            offset += (alignment - offset % alignment) % alignment
            t["offset"] = offset

        for name in tensor_info:
            self.tensor_file_map[name] = f.name
        self.tensor_info.update(tensor_info)
        self.gguf_file_meta.update(info)
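The padding arithmetic above is the part of load_gguf that is easiest to get wrong, so here is a small worked check of the same formula with made-up numbers:

start, bad_offset, alignment = 448, 100, 32   # illustrative values only
offset = start + bad_offset                   # 548, not a multiple of 32
offset += (alignment - offset % alignment) % alignment
assert offset == 576                          # first multiple of 32 at or after 548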
    def get_mmap_tensor(self, name):
        t = self.tensor_info[name]
        mmap_data = self.file_data_map[ self.tensor_file_map[name] ]

        offset = t["offset"]
        item_type = t["item_type"]
        item_count = t["item_count"]
        itemsize = int(np.empty([], dtype = item_type).itemsize)
        return mmap_data[offset : offset + itemsize * item_count]
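For quantized types the loop in load_gguf stores item_type = np.uint8 and item_count = n_bytes, so the memmap slice returned above is exactly n_elems * type_size // block_size bytes long. A quick sanity check assuming Q8_0's (32, 34) elements-per-block / bytes-per-block pair and a made-up tensor shape:

n_elems = 4096 * 8192                  # hypothetical tensor
block_size, type_size = 32, 34         # Q8_0: 32 int8 weights plus a float16 scale per block
n_bytes = n_elems * type_size // block_size
assert n_bytes == 35_651_584           # length of the returned memmap slice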
    def get_undequanted_tensor_and_ggml_type(self, name):
        t = self.tensor_info[name]
        data = self.get_mmap_tensor(name)
        ggml_type = t["ggml_type"]
        data = torch.from_numpy(data)
        return data, ggml_type
def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
|
||||
t = self.tensor_info[name]
|
||||
if device.lower() == "cpu":
|
||||
print(f"loading expert {expert_id} of {name} with CPU")
|
||||
shape = t["shape"]
|
||||
ggml_type = t["ggml_type"]
|
||||
if ggml_type not in GGML_NAMES:
|
||||
raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
|
||||
ggml_name = GGML_NAMES[ggml_type]
|
||||
|
||||
# TODO: experts may fused in quant block, split it
|
||||
assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant"
|
||||
|
||||
blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]
|
||||
block_size = GGML_BLOCK_SIZES[ggml_name]
|
||||
offset = expert_id * block_size * blocks_per_experts
|
||||
data = data[offset: offset + block_size * blocks_per_experts]
|
||||
|
||||
if "cuda" in device.lower():
|
||||
values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
|
||||
else:
|
||||
values = GGML_DEQUANTIZE[ggml_name](data)
|
||||
values = torch.from_numpy(values.copy())
|
||||
|
||||
if ggml_name == "BF16":
|
||||
values = values.view(torch.bfloat16)
|
||||
values = values.view(shape[-2::-1])
|
||||
|
||||
return values
|
||||
|
||||
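load_expert_tensor slices one expert out of a fused MoE tensor by pure block arithmetic, which is why elements_per_expert must be a multiple of the type's elements-per-block. A worked example, again assuming Q8_0 sizes and a hypothetical per-expert weight of 2048 x 7168:

elements_per_expert = 2048 * 7168               # hypothetical expert weight size
blocks_per_expert = elements_per_expert // 32   # 458,752 Q8_0 blocks per expert
byte_offset = 3 * 34 * blocks_per_expert        # start of expert_id = 3 in the fused buffer
assert (blocks_per_expert, byte_offset) == (458_752, 46_792_704)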
    def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":
            print(f"loading {name} with CPU")
        if target_dtype == None:
            target_dtype = torch.get_default_dtype()

        shape = t["shape"]
        ggml_type = t["ggml_type"]

        if ggml_type not in GGML_NAMES:
            raise NotImplementedError(f"ggml_type {ggml_type} not implemented")

        ggml_name = GGML_NAMES[ggml_type]

        data = self.get_mmap_tensor(name)

        block_size = GGML_BLOCK_SIZES[ggml_name]
        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
        num_elements = int(np.prod(shape))
        num_blocks = num_elements // elements_per_block

        blocks_per_iter = 16384
        if num_blocks > blocks_per_iter: # dequant large tensor
            values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
                blocks_begin = i * blocks_per_iter
                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
                if "cuda" in device.lower():
                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
                else:
                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
                    cur_values = torch.from_numpy(cur_values.copy())

                cur_values = cur_values.view(-1, elements_per_block)
                if ggml_name == "BF16":
                    cur_values = cur_values.view(torch.bfloat16)
                values[blocks_begin : blocks_end] = cur_values
        else:
            if "cuda" in device.lower():
                values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
            else:
                values = GGML_DEQUANTIZE[ggml_name](data)
                values = torch.from_numpy(values)

            if ggml_name == "BF16":
                values = values.view(torch.bfloat16)

        values = values.view(shape[::-1])
        if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
            n_head = self.gguf_file_meta['llama.attention.head_count']
            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
                      .swapaxes(1, 2)
                      .reshape(values.shape))
        elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
            n_head = self.gguf_file_meta['llama.attention.head_count_kv']
            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
                      .swapaxes(1, 2)
                      .reshape(values.shape))
        return values
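The attn_q / attn_k branch at the end of load_gguf_tensor undoes the per-head row permutation applied when llama-architecture checkpoints are converted to GGUF, restoring the row order that Hugging Face-style attention expects. A tiny numpy illustration of what the reshape/swapaxes does to row order (toy sizes, not from the commit):

import numpy as np

n_head, hidden = 2, 3                                    # toy sizes: 2 heads of dim 4
w = np.arange(n_head * 4)[:, None] * np.ones((1, hidden))
out = (w.reshape(n_head, w.shape[0] // n_head // 2, 2, *w.shape[1:])
         .swapaxes(1, 2)
         .reshape(w.shape))
assert list(out[:, 0].astype(int)) == [0, 2, 1, 3, 4, 6, 5, 7]   # rows regrouped per head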
def read_value(f, data_type):
    if data_type == DATA_TYPES["string"]:
@@ -921,6 +688,7 @@ def translate_name_to_gguf(name):
    name = name.replace(".gate_up_proj.", ".up_proj")

    name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp")
    name = name.replace(".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias")
    name = name.replace(".mlp.gate", ".ffn_gate_inp")
    name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
    name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
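These substring rules map Hugging Face-style module names for DeepSeek-style MoE layers onto GGUF tensor-name suffixes. A minimal check of just the shared-expert and gate rules in this hunk (the helper below is an illustrative stand-in, not the full translate_name_to_gguf):

def _translate_suffixes(name: str) -> str:
    # Same replacement order as above; the longer ".mlp.gate.e_score_correction_bias"
    # rule must run before the generic ".mlp.gate" rule.
    for old, new in [
        (".mlp.shared_experts.down_proj", ".ffn_down_shexp"),
        (".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias"),
        (".mlp.gate", ".ffn_gate_inp"),
        (".mlp.shared_experts.gate_proj", ".ffn_gate_shexp"),
        (".mlp.shared_experts.up_proj", ".ffn_up_shexp"),
    ]:
        name = name.replace(old, new)
    return name

assert _translate_suffixes("model.layers.0.mlp.gate.e_score_correction_bias") == "model.layers.0.exp_probs_b.bias"
assert _translate_suffixes("model.layers.0.mlp.shared_experts.up_proj.weight") == "model.layers.0.ffn_up_shexp.weight"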