[ADD] support multi-gpu qlen>1 q5_k

This commit is contained in:
chenxl 2024-08-12 11:17:29 +00:00
parent f293803156
commit f5f79f5c0e
63 changed files with 3271 additions and 1285 deletions


@@ -5,8 +5,11 @@ Description :
Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-26 08:48:54
Version : 1.0.0
LastEditors : Azure
LastEditTime : 2024-07-26 09:28:25
LastEditors : kkk1nak0
LastEditTime : 2024-08-09 08:03:44
Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
Copyright (c) 2023-2024 The ggml authors
Copyright (c) 2024 Thomas Germer
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
# copied from llama.cpp/gguf-py/gguf/constants.py to satisfy the gguf dependency
@@ -15,6 +18,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import struct
import warnings
import numpy as np
import re
import numpy.typing as npt
from typing import Sequence
import os
@@ -96,6 +100,8 @@ def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantization
GGML_TYPES = {
"F32": 0,
"F16": 1,
"Q4_0": 2,
"Q5_0": 6,
"Q8_0": 8,
"Q2_K": 10,
"Q3_K": 11,
@@ -109,6 +115,8 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
GGML_BLOCK_SIZES = {
"F32": 4,
"F16": 2,
"Q4_0": 2 + 16,
"Q5_0": 2 + 4 + 16,
"Q8_0": 2 + 32,
"Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
"Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
@@ -120,6 +128,8 @@ GGML_BLOCK_SIZES = {
GGML_ELEMENTS_PER_BLOCK = {
"F32": 1,
"F16": 1,
"Q4_0": 32,
"Q5_0": 32,
"Q8_0": 32,
"Q2_K": 256,
"Q3_K": 256,
@@ -128,14 +138,6 @@ GGML_ELEMENTS_PER_BLOCK = {
"Q6_K": 256,
}
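The new Q4_0 and Q5_0 entries keep these tables consistent with the byte layouts above: a Q4_0 block is a float16 scale plus 16 bytes of packed 4-bit values (18 bytes for 32 weights), and a Q5_0 block adds 4 bytes of high bits (22 bytes for 32 weights). As a quick sanity check (not part of the commit), the effective bytes per weight follow directly from the two dictionaries:
for qtype in ("Q4_0", "Q5_0", "Q8_0"):
    bpw = GGML_BLOCK_SIZES[qtype] / GGML_ELEMENTS_PER_BLOCK[qtype]
    print(f"{qtype}: {GGML_BLOCK_SIZES[qtype]} bytes/block, {bpw} bytes/weight")
# Q4_0: 18 bytes/block, 0.5625 bytes/weight
# Q5_0: 22 bytes/block, 0.6875 bytes/weight
# Q8_0: 34 bytes/block, 1.0625 bytes/weight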
# DATA_TYPES = {
# "uint32": 4,
# "int32": 5,
# "float32": 6,
# "string": 8,
# "array": 9,
# "uint64": 10,
# }
DATA_TYPES = {
"uint8": 0,
"int8": 1,
@@ -167,6 +169,7 @@ class GGUFLoader:
self.tensor_file_map = {}
self.file_data_map = {}
self.gguf_file_meta = {}
self.tensor_device_map = {}
# Walk through all the .gguf files in the directory
for root, dirs, files in os.walk(gguf_path):
@@ -272,7 +275,7 @@ class GGUFLoader:
def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
t = self.tensor_info[name]
shape = t["shape"]
ggml_type = t["ggml_type"]
@@ -282,15 +285,28 @@ GGUFLoader:
ggml_name = GGML_NAMES[ggml_type]
data = self.get_mmap_tensor(name)
if "cuda" in device.lower():
values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
#values = GGML_DEQUANTIZE[ggml_name](data)
#print("load_gguf_tensor")
#values = torch.from_numpy(values).to(device = device)
else:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values)
return values.view(shape[::-1])
values = values.view(shape[::-1])
if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
n_head = self.gguf_file_meta['llama.attention.head_count']
values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
.swapaxes(1, 2)
.reshape(values.shape))
elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
n_head = self.gguf_file_meta['llama.attention.head_count_kv']
values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
.swapaxes(1, 2)
.reshape(values.shape))
return values
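The reshape/swapaxes pair appears to undo the head-wise Q/K permutation that llama.cpp's converter applies to Llama-family checkpoints for its rotary-embedding layout, so the weights load back in HuggingFace row order. A toy illustration of the same transform (not part of the commit; the shapes are made up):
import numpy as np
n_head, hidden = 2, 8
w = np.arange(8 * hidden).reshape(8, hidden)          # 4 rows per head
out = (w.reshape(n_head, w.shape[0] // n_head // 2, 2, *w.shape[1:])
       .swapaxes(1, 2)
       .reshape(w.shape))
# Within each head the two halves are interleaved: rows 0,1,2,3 -> 0,2,1,3.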
def read_value(f, data_type):
if data_type == DATA_TYPES["string"]:
@@ -375,7 +391,7 @@ def dequantize_q2_k(data):
return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
def dequantize_q2_k_gpu(data):
pass
raise NotImplementedError()
def dequantize_q3_k(data):
# C implementation
@@ -420,7 +436,7 @@ def dequantize_q3_k(data):
], axis=1)
def dequantize_q3_k_gpu(data):
pass
raise NotImplementedError()
def dequantize_q4_k(data):
# C implementation
@@ -429,20 +445,16 @@ def dequantize_q4_k(data):
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
block_size = GGML_BLOCK_SIZES["Q4_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
# Casting to float32 because float16 is very slow on CPU
scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
# Dequantize scales and offsets (6 bits and 4 + 2 bits)
factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
# Interleave low and high quantized bits
qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
# Dequantize final weights using scales and offsets
@@ -512,9 +524,14 @@ def dequantize_q5_k(data):
d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
], axis=1)
def dequantize_q5_k_gpu(data):
pass
def dequantize_q5_k_gpu(data, device:str ="cuda"):
block_size = GGML_BLOCK_SIZES["Q5_K"]
data = np.frombuffer(data, dtype=data.dtype)
device = torch.device(device)
# TODO: this, like the from_numpy calls in the other dequantize functions, triggers a warning that the numpy array is not writable;
# the proper fix is to pass the data pointer to KTransformersOps instead of a Tensor.
data = torch.from_numpy(data)
return KTransformersOps.dequantize_q5_k(data, block_size, device)
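dequantize_q5_k_gpu hands the raw block bytes to the new KTransformersOps CUDA kernel instead of going through the numpy path. A minimal consistency check (not part of the commit; `raw` is assumed to be the mmap'ed bytes of a Q5_K tensor, e.g. from loader.get_mmap_tensor(name)):
cpu_vals = dequantize_q5_k(raw)                       # numpy float32 on the CPU
gpu_vals = dequantize_q5_k_gpu(raw, "cuda:0")         # torch tensor on the GPU
assert np.allclose(cpu_vals, gpu_vals.float().cpu().numpy(), atol=1e-3)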
def dequantize_q6_k(data):
# C implementation
@@ -571,7 +588,49 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
num_blocks = len(data) // block_size
data = np.frombuffer(data, dtype=data.dtype)
data = torch.from_numpy(data)
return KTransformersOps.dequantize_q6_k(data, 210, device)
return KTransformersOps.dequantize_q6_k(data, block_size, device)
def dequantize_q4_0(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515
# C struct definition
# https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L141
num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)
qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]
return np.concatenate([
scales * ((qs & 0xf).astype(np.int8) - 8),
scales * ((qs >> 4).astype(np.int8) - 8),
], axis=1)
def dequantize_q4_0_gpu(data):
raise NotImplementedError()
def dequantize_q5_0(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1556
# C struct definition
# https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L161
num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]
bits = np.unpackbits(qh, axis=-1, bitorder="little")
x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16
return np.concatenate([
scales * x0,
scales * x1,
], axis=1)
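Both new CPU dequantizers use the same nibble decode: (qs & 0xf).astype(np.int8) - 8 maps the low 4 bits from 0..15 to -8..7, and Q5_0 additionally pulls a fifth bit per weight out of qh with np.unpackbits before shifting the range to -16..15. A one-byte example (not part of the commit):
import numpy as np
qs = np.array([0x2F], dtype=np.uint8)     # low nibble 0xF, high nibble 0x2
low  = (qs & 0xf).astype(np.int8) - 8     # 15 - 8 ->  7
high = (qs >> 4).astype(np.int8) - 8      #  2 - 8 -> -6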
def dequantize_q5_0_gpu(data):
raise NotImplementedError()
def dequantize_q8_0(data):
# C struct definition
@@ -615,6 +674,8 @@ def dequantize_f16_gpu(data, device):
GGML_DEQUANTIZE = {
"F32": dequantize_f32,
"F16": dequantize_f16,
"Q4_0": dequantize_q4_0,
"Q5_0": dequantize_q5_0,
"Q8_0": dequantize_q8_0,
"Q2_K": dequantize_q2_k,
"Q3_K": dequantize_q3_k,
@@ -626,6 +687,8 @@ GGML_DEQUANTIZE = {
GGML_DEQUANTIZE_GPU = {
"F32": dequantize_f32_gpu,
"F16": dequantize_f16_gpu,
"Q4_0": dequantize_q4_0_gpu,
"Q5_0": dequantize_q5_0_gpu,
"Q8_0": dequantize_q8_0_gpu,
"Q2_K": dequantize_q2_k_gpu,
"Q3_K": dequantize_q3_k_gpu,
@@ -634,7 +697,34 @@ GGML_DEQUANTIZE_GPU = {
"Q6_K": dequantize_q6_k_gpu,
}
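These two dispatch tables are what load_gguf_tensor uses above: the tensor's ggml_type id is turned into a name via GGML_NAMES and then routed to the numpy or CUDA dequantizer depending on the target device. Roughly (a sketch, not part of the commit):
def dequantize(data, ggml_type: int, device: str = "cpu"):
    name = GGML_NAMES[ggml_type]                          # e.g. 6 -> "Q5_0"
    if "cuda" in device.lower():
        return GGML_DEQUANTIZE_GPU[name](data, device)    # torch tensor on the GPU
    return torch.from_numpy(GGML_DEQUANTIZE[name](data))  # numpy, then torch on CPU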
def translate_name_to_gguf_mixtral(name):
replacement_template = {
"w1.weight": "ffn_gate",
"w2.weight": "ffn_down",
"w3.weight": "ffn_up"
}
pattern = re.compile(r"model.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(w\d\.weight)")
def replace_match(match):
blk_id = match.group(1)
expert_id = match.group(2)
weight_type = match.group(3)
if weight_type in replacement_template:
return f"blk.{blk_id}.{replacement_template[weight_type]}.{expert_id}.weight"
else:
return match.group(0)
new_name = re.sub(pattern, replace_match, name)
return new_name
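The regex maps Mixtral-style HuggingFace expert names onto GGUF's per-expert tensor names, leaving anything else untouched. For example:
translate_name_to_gguf_mixtral("model.layers.0.block_sparse_moe.experts.3.w1.weight")
# -> "blk.0.ffn_gate.3.weight"
translate_name_to_gguf_mixtral("model.layers.5.block_sparse_moe.experts.7.w2.weight")
# -> "blk.5.ffn_down.7.weight"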
def translate_name_to_gguf(name):
name = translate_name_to_gguf_mixtral(name)
name = name.replace("lm_head.", "output.")
name = name.replace("model.embed_tokens.", "token_embd.")
name = name.replace("model.norm.", "output_norm.")
@@ -671,9 +761,14 @@ def translate_name_to_gguf(name):
name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
name = name.replace(".block_sparse_moe.experts", "")
return name
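translate_name_to_gguf first applies the Mixtral rewrite and then a chain of plain string replacements from HuggingFace module names to GGUF tensor names, for example:
translate_name_to_gguf("lm_head.weight")             # -> "output.weight"
translate_name_to_gguf("model.embed_tokens.weight")  # -> "token_embd.weight"
translate_name_to_gguf("model.norm.weight")          # -> "output_norm.weight"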
if __name__ == '__main__':
gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'
loader = GGUFLoader(gguf_path)
loader.load_gguf_tensor('token_embd.weight')