Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 14:51:06 +00:00)
Add data loader to read special weights for fp8; Add special weight process script
This commit is contained in:
parent 7b7c6a657d
commit 581a524f65
10 changed files with 481 additions and 26 deletions
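Background for the changes below: DeepSeek-V3-style FP8 checkpoints store each `weight` in float8 together with a per-block `weight_scale_inv` tensor, and the new loader path dequantizes by multiplying every tile of the weight by its scale. A minimal sketch of that idea (the diff itself relies on the repo's `weight_dequant` Triton kernel; the block size of 128 and the helper below are illustrative assumptions, not the committed implementation):

import torch

def dequant_blockwise(w_fp8: torch.Tensor, scale_inv: torch.Tensor, block: int = 128) -> torch.Tensor:
    # Illustrative only: expand the per-block scale grid back to the weight's
    # full shape and multiply, which is what block-wise dequantization amounts to.
    w = w_fp8.to(torch.float32)
    n, k = w.shape
    scales = scale_inv.repeat_interleave(block, dim=0)[:n]
    scales = scales.repeat_interleave(block, dim=1)[:, :k]
    return w * scales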
@@ -1,3 +1,4 @@
+# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
 from typing import Tuple
 
 import torch
@@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
         down_type = None
 
         for key in keys:
-            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+            if self.gguf_loader.safetensor_loader is not None:
+                # using a temp ugly way to temporarily load the tensor
+                gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
+                up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
+                down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
+                gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
+                up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
+                down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
+
+            elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
                 gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
                 up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
                 down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
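Note on the safetensor branch above: the expert tensors stay in their GGML-packed layout, and the quantization format is recovered from a companion `<key>.ggml_type` entry that the merge script at the end of this commit writes next to each expert weight. A rough sketch of inspecting such a sidecar entry, with a hypothetical shard name and key:

from safetensors import safe_open

# placeholder shard name, as produced by merge_safetensor_gguf.py
with safe_open("model-00001-of-00061.safetensors", framework="pt") as f:
    packed = f.get_tensor("blk.3.ffn_gate_exps.weight")                # still GGML-packed bytes
    ggml_type = f.get_tensor("blk.3.ffn_gate_exps.ggml_type").item()   # int, maps to GGMLQuantizationType
    print(packed.shape, ggml_type)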
@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):
 
         for key in keys:
             key = ".".join(key.split(".")[:-1])
-            if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
+            if self.gguf_loader.safetensor_loader is not None:
+                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
+                weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
+                e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
+                weight_type = weight.dtype
+                e_score_correction_bias_type = e_score_correction_bias.dtype
+                res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
+            elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
                 targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                 tensors = self.load_multi(key, targets, device=device)
                 weight = tensors[".ffn_gate_inp.weight"]
@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
             self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
         else:
             raise ValueError("Invalid weight type")
-        self.orig_module.weight = self.orig_module.weight.to(device)
-        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
+        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
 
     def unload(self):
         if self.weight is not None:
@@ -76,7 +76,13 @@ class KLinearBase(ABC):
             keys = [self.key]
 
         for key in keys:
-            if key + ".weight" in self.gguf_loader.tensor_file_map:
+            if self.gguf_loader.safetensor_loader is not None:
+                # using safetensor_loader
+                tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
+                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
+                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+
+            elif key + ".weight" in self.gguf_loader.tensor_file_map:
                 if key + ".bias" in self.gguf_loader.tensor_file_map:
                     tensors = self.load_multi(key, ["weight", "bias"], device=device)
                     tensor = tensors["weight"]
@@ -166,6 +172,8 @@ class KLinearTorch(KLinearBase):
             self.bias = None
 
 class KLinearFP8(KLinearBase):
+    # this kernel requires special handling for weight
+    # Please load the weight file downloaded from KVCache.AI
     marlin_q_w: torch.Tensor
     marlin_s: torch.Tensor
     g_idx: torch.Tensor
@@ -191,26 +199,20 @@ class KLinearFP8(KLinearBase):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.to(self.device)
-        orig_shape = list(x.shape)
         orig_dtype = x.dtype
-        x = x.reshape(-1, orig_shape[-1])
         x_quantized, scale_x = act_quant(x, self.block_size)
-        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
-        if self.bias is not None:
-            y += self.bias
-        return y.to(orig_dtype).reshape(orig_shape)
+        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
+        return y.to(dtype=orig_dtype)
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
         if device is None: device = self.device
         if w is None:
             w = self.load_weight(device=device)
-        if isinstance(w, nn.Parameter):
-            self.weight = w.to(device)
-            self.has_bias = False
-        elif isinstance(w, tuple):
+        ### TODO fit weight_inv format
+        if isinstance(w, tuple):
             self.weight = w[0].to(device)
-            self.bias = w[1].to(device)
-            self.has_bias = True
+            self.weight_scale_inv = w[1].to(device)
+            self.has_bias = False
         else:
             raise ValueError("Invalid weight type")
         self.weight = self.weight.to(device)
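The rewritten `forward` above quantizes the activation block-wise and hands both scale grids to `fp8_gemm`, so its output should approximate a matmul against the dequantized weight. A small numerical sanity check in the spirit of the test file later in this commit; the random FP8 weight, the shapes, and the block size of 128 are assumptions, while the helper signatures come from the fp8gemm Triton module used throughout the diff:

import torch
from ktransformers.ktransformers_ext.triton.fp8gemm import act_quant, fp8_gemm, weight_dequant

block_size = 128
M, K, N = 64, 512, 256
w = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)             # fake FP8 weight
s_w = torch.rand(N // block_size, K // block_size, device="cuda") + 0.5  # fake weight_scale_inv
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

x_q, s_x = act_quant(x, block_size)
y_fp8 = fp8_gemm(x_q, s_x, w, s_w)                                       # the KLinearFP8.forward path
y_ref = x @ weight_dequant(w, s_w).to(torch.bfloat16).T                  # dequantized reference
print((y_fp8.float() - y_ref.float()).abs().max())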
|
@ -425,7 +427,8 @@ class KLinearCPUInfer(KLinearBase):
|
||||||
LINEAR_MAP = {
|
LINEAR_MAP = {
|
||||||
"KLinearMarlin": KLinearMarlin,
|
"KLinearMarlin": KLinearMarlin,
|
||||||
"KLinearTorch": KLinearTorch,
|
"KLinearTorch": KLinearTorch,
|
||||||
"KLinearCPUInfer": KLinearCPUInfer
|
"KLinearCPUInfer": KLinearCPUInfer,
|
||||||
|
"KLinearFP8": KLinearFP8,
|
||||||
}
|
}
|
||||||
|
|
||||||
class KTransformersLinear(BaseInjectedModule, KLinearBase):
|
class KTransformersLinear(BaseInjectedModule, KLinearBase):
|
||||||
|
@@ -472,10 +475,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
     def forward(self, x):
         if self.mode == InferenceState.PREFILL:
             assert self.prefill_linear is not None, "cpu linear is not initialized"
-            return self.prefill_linear.forward(x)
+            y = self.prefill_linear.forward(x)
         else:
             assert self.generate_linear is not None, "gpu linear is not initialized"
-            return self.generate_linear.forward(x)
+            y = self.generate_linear.forward(x)
+        return y
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
         if not mode:
@@ -0,0 +1,63 @@ (new YAML rule file)
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
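Each rule in this new config selects modules by `name` (a regular expression over the module path) and/or `class`, then injects the replacement class with the given kwargs; the second rule is what routes matching `torch.nn.Linear` layers through `KLinearFP8` at generation time. A simplified sketch of the matching predicate, with hypothetical helper names (the real logic lives in ktransformers' optimize/injection machinery):

import re
import torch.nn as nn

def rule_matches(match_spec: dict, module_name: str, module: nn.Module) -> bool:
    # A rule applies only if every criterion it specifies matches.
    if "name" in match_spec and re.match(match_spec["name"], module_name) is None:
        return False
    if "class" in match_spec:
        cls_path = type(module).__module__ + "." + type(module).__name__
        if cls_path != match_spec["class"]:
            return False
    return True

# e.g. "model.layers.0.self_attn.q_proj" passes the Linear rule, while
# "model.layers.0.self_attn.kv_b_proj" is excluded by the negative lookahead.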
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 from typing import Optional
 import pytest
 from typing import Tuple, Optional, Literal
-
+import time
 # use dir path
 import os
 import sys
@@ -56,18 +56,61 @@ def test_fp8_gemm_vs_torch_matmul_load():
     print(f"weight_dequantized: {weight_dequantized.shape}")
     N, K = weight_dequantized.shape
     M = 64
-    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+    x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda')
     x_quantized, scale_x = act_quant(x, block_size)
 
     # Test case 1: quantized x matmal with undequantized weight
     result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
     print(f"result_fp8_gemm:\n {result_fp8_gemm}")
+    print(f"dtype {result_fp8_gemm.dtype}")
 
     # Perform torch.matmul using the original floating point tensors
     result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
     print(f"result_torch_matmul:\n {result_torch_matmul}")
 
+def test_fp8_gemm_tplops():
+    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
+    with safe_open(file_path, framework="pt", device=0) as f:
+        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
+        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
+
+    # weight_dequant
+    weight_dequantized = weight_dequant(weight, scale)
+    print(f"weight_dequantized: {weight_dequantized.shape}")
+    N, K = weight_dequantized.shape
+    M = 6400
+    x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda')
+    # x_quantized, scale_x = act_quant(x, block_size)
+
+    # Calculate time for 1000 fp8_gemm
+    i = 10
+    flops_per_gemm = 2 * M * N * K
+    total_flops = i * flops_per_gemm
+
+    x_quantized, scale_x = act_quant(x, block_size)
+    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+    x_quantized, scale_x = act_quant(x, block_size)
+    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+
+    t0 = time.time()
+    torch.cuda.synchronize()
+    for i in range(i):
+        x_quantized, scale_x = act_quant(x, block_size)
+        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+    torch.cuda.synchronize()
+    t1 = time.time()
+
+    total_time = t1 - t0
+    tflops = total_flops / total_time / 1e12
+    print(f"total_time: {total_time}")
+    print(f"tflops: {tflops}")
+
 if __name__ == "__main__":
     test_fp8_gemm_vs_torch_matmul()
     test_fp8_gemm_vs_torch_matmul_load()
+    test_fp8_gemm_tplops()
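For the throughput test just added, the accounting is the usual GEMM FLOP count: one (M, K) x (K, N) product costs about 2*M*N*K operations, so TFLOPS = iterations * 2*M*N*K / elapsed_seconds / 1e12 (note the script's `flops_per_gemm` does not include the extra leading batch dimension of 2 on `x`). The same arithmetic as a tiny helper:

def gemm_tflops(m: int, n: int, k: int, iters: int, seconds: float, batch: int = 1) -> float:
    # 2*M*N*K FLOPs per GEMM, scaled by batch and iteration count, over wall time
    return batch * iters * 2 * m * n * k / seconds / 1e12

# e.g. the script's figure corresponds to gemm_tflops(6400, N, K, iters=10, seconds=total_time)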
@@ -25,6 +25,7 @@ import os
 from enum import IntEnum
 import torch
 import KTransformersOps
+from .custom_loader import SafeTensorLoader
 
 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -168,6 +169,7 @@ class GGUFLoader:
     gguf_path: str
     tensor_file_map: dict # {tensor_name: tensor_file_path}
     gguf_file_meta: dict
+    safetensor_loader: SafeTensorLoader
     def __init__(self, gguf_path: str):
         # Check dir exist
         if not os.path.exists(gguf_path):
@@ -175,6 +177,8 @@ class GGUFLoader:
         if os.path.isfile(gguf_path):
             gguf_path = os.path.dirname(gguf_path)
 
+        self.safetensor_loader = None
+
         self.tensor_info = {}
         self.gguf_path = gguf_path
         self.tensor_file_map = {}
@@ -182,6 +186,12 @@ class GGUFLoader:
         self.gguf_file_meta = {}
         self.tensor_device_map = {}
 
+        # I know this is ugly, but I don't want to change the original code too much
+        # TODO: merge gguf load and other loads.
+        safetensor_loader = SafeTensorLoader(gguf_path)
+        if safetensor_loader.tensor_file_map:
+            self.safetensor_loader = safetensor_loader
+            return
         # Walk through all the .gguf files in the directory
         found_gguf = False
         for root, dirs, files in os.walk(gguf_path):
@@ -288,6 +298,13 @@ class GGUFLoader:
         itemsize = int(np.empty([], dtype = item_type).itemsize)
         return mmap_data[offset : offset + itemsize * item_count]
 
+    def get_undequanted_tensor_and_ggml_type(self, name):
+        t = self.tensor_info[name]
+        data = self.get_mmap_tensor(name)
+        ggml_type = t["ggml_type"]
+        data = torch.from_numpy(data)
+        return data, ggml_type
+
     def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
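With the constructor changes above, `GGUFLoader` first probes the target directory for `.safetensors` files; if a `SafeTensorLoader` finds any, it is attached as `self.safetensor_loader` and the GGUF directory walk is skipped, which is why the operator hunks earlier branch on `gguf_loader.safetensor_loader`. A minimal usage sketch, with a placeholder path and an illustrative key:

from ktransformers.util.custom_gguf import GGUFLoader

loader = GGUFLoader("/path/to/checkpoint-dir")   # placeholder path
if loader.safetensor_loader is not None:
    w = loader.safetensor_loader.load_tensor("blk.0.attn_norm.weight")    # illustrative key
else:
    w = loader.load_gguf_tensor("blk.0.attn_norm.weight", device="cpu")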
ktransformers/util/custom_loader.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import struct
import warnings
import numpy as np
import re
import numpy.typing as npt
from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors.torch import save_file

class SafeTensorLoader:
    tensor_file_map = {}
    tensor_type_map = {}
    file_handle_map = {}

    def __init__(self, file_path: str):
        self.__load_tensor_file_map(file_path)

    def __load_tensor_file_map(self, file_path: str):
        # Normalize the given path and make sure we end up with a folder path
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Path not found: {file_path}")
        if os.path.isfile(file_path):
            folder_path = os.path.dirname(file_path)
        else:
            folder_path = file_path

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            files = sorted(files)
            for file in files:
                if file.endswith(".safetensors"):
                    found_safetensor = True
                    file_path = os.path.join(root, file)
                    if file not in self.file_handle_map:
                        try:
                            handle = safe_open(file_path, framework="pt")
                            self.file_handle_map[file] = handle
                        except Exception as e:
                            print(f"Error opening Safetensor file {file_path}: {e}")
                            continue

                    f = self.file_handle_map.get(file)
                    if f is None:
                        continue
                    try:
                        for key in f.keys():
                            self.tensor_file_map[key] = file
                    except Exception as e:
                        print(f"Error reading Safetensor file {file_path}: {e}")

        # if not found_safetensor:
        #     raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, key: str, device: str="cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        return tensor.to(device)

    def close_all_handles(self):
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    def load_dequantized_tensor(self, key:str, device: str="cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key).to(device)
        if key.endswith(".weight"):
            if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
                weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
                tensor = weight_dequant(tensor, weight_scale_inv)
        return tensor.to(device)
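A short usage sketch for the new loader: `load_tensor` returns a tensor exactly as stored, while `load_dequantized_tensor` additionally applies `weight_dequant` whenever a matching `.weight_scale_inv` entry exists for a `.weight` key. The path and key below are placeholders borrowed from the test file in this commit:

from ktransformers.util.custom_loader import SafeTensorLoader

st = SafeTensorLoader("/mnt/data/model/DeepSeek-V3")             # any dir containing *.safetensors
raw = st.load_tensor("model.layers.0.mlp.down_proj.weight")      # FP8 weight, as stored
deq = st.load_dequantized_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")
st.close_all_handles()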
@@ -66,12 +66,23 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
     for name, param in local_state.items():
         key = prefix + name
         translated_key = translate_name_to_gguf(key)
-        if translated_key in gguf_loader.tensor_file_map:
+
+        # TODO: Merge all loader.
+        # I know this is ugly but lets do it for now.
+        if gguf_loader.safetensor_loader is not None:
+            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
+            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
+        else:
+            load_dequantized_tensor = gguf_loader.load_gguf_tensor
+            tensor_file_map = gguf_loader.tensor_file_map
+
+        if translated_key in tensor_file_map:
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
             torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
-            weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
+            # weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
+            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
         else:
merge_tensors/merge_safetensor_gguf.py (new file, 214 lines)
@@ -0,0 +1,214 @@
# this script targets to merge the fp8 safe tensor and the gguf quantized tensors.

import os
# insert the path of the project
import sys
sys.path.insert(0, "/home/azure/ktransformers")
import argparse
import torch
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
from safetensors import safe_open
from safetensors.torch import save_file
import re
from collections import defaultdict

def read_safetensor_keys_from_folder(folder_path)->dict:
    """
    :param folder_path: folder path
    :return: key_to_file_map
    """
    # check if the folder path exists
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"GGUF dir not found: {folder_path}")
    if os.path.isfile(folder_path):
        folder_path = os.path.dirname(folder_path)

    key_to_file_map = {}

    found_safetensor = False
    for root, dirs, files in os.walk(folder_path):
        # sort files
        files = sorted(files)
        for file in files:
            if file.endswith(".safetensors"):
                found_safetensor = True
                file_path = os.path.join(root, file)
                try:
                    with safe_open(file_path, framework="pt") as f:
                        for key in f.keys():
                            if "model.layers.61" in key:
                                # skip MTP layer
                                continue
                            # try:
                            #     if int(key.split('.')[2]) > 4:
                            #         continue
                            # except:
                            #     pass
                            key_to_file_map[key] = file_path
                except Exception as e:
                    print(f"Error reading Safetensor file {file_path}: {e}")

    if not found_safetensor:
        raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    return key_to_file_map

tensor_from_gguf = []  # todo: add keys in gguf that should be used in the final tensor

def translate_name(name:str)->str:
    """
    :param name: name of the tensor
    :return: translated name
    """
    name = translate_name_to_gguf(name)
    name = name.replace(".up_proj.", ".ffn_up_exps.")
    name = name.replace(".down_proj.", ".ffn_down_exps.")
    name = name.replace(".gate_proj.", ".ffn_gate_exps.")
    name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias")
    return name


def combine_tensor_sources(safetensor_path:str, gguf_path:str):
    gguf_loader = GGUFLoader(gguf_path)
    gguf_tensor_file_map = gguf_loader.tensor_file_map
    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)

    # build a map for the key to the tensor
    # according to the key, we can get the tensor from the file

    target_tensor_map = {}
    for key in safetensor_tensor_file_map.keys():
        # for all experts, we use the gguf tensor
        if ".mlp.experts." in key:
            if '.weight_scale_inv' in key:
                continue
            key = '.'.join(key.split('.')[:5]+key.split('.')[-2:])
            translated_key = translate_name(key)
            target_tensor_map[key] = gguf_tensor_file_map[translated_key]
            continue

        if any(target_key in key for target_key in tensor_from_gguf):
            target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]
        else:
            target_tensor_map[key] = safetensor_tensor_file_map[key]

    return target_tensor_map, gguf_loader

def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)

    # Cache for safetensor file handles and GGUF loaders
    safetensors_cache = {}
    gguf_cache = {}

    # Group tensors by layer
    layer_groups = defaultdict(list)
    non_layer_keys = []
    layer_pattern = re.compile(r'\.layers\.(\d+)\.')

    for key in target_tensor_map:
        match = layer_pattern.search(key)
        if match:
            layer_num = int(match.group(1))
            layer_groups[layer_num].append(key)
        else:
            non_layer_keys.append(key)

    # Calculate total shards
    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1
    if total_shards == 0:
        raise ValueError("No tensors to save")

    shard_idx = 0

    # Save non-layer tensors to the first shard if they exist
    if non_layer_keys:
        tensors = {}
        for key in non_layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
            elif file_path.endswith('.gguf'):
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            tensors[translate_name(key)] = tensor
            if ggml_type:
                ggml_type = torch.tensor(ggml_type)
                ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
                tensors[ggml_key] = ggml_type

        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving non-layer tensors to {output_file}")
        save_file(tensors, output_file)
        print(tensors.keys())

        shard_idx += 1

    # Save each layer's tensors to subsequent shards
    for layer_num in sorted(layer_groups.keys()):
        layer_keys = layer_groups[layer_num]
        tensors = {}
        for key in layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
                tensor_info = tensor.shape
            elif file_path.endswith('.gguf'):
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
                # tensor_info = gguf_loader.tensor_info[gguf_name]
                # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type']
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            tensors[translate_name(key)] = tensor
            if ggml_type:
                ggml_type = torch.tensor(ggml_type)
                ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
                tensors[ggml_key] = ggml_type

        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving layer {layer_num} to {output_file}")
        print(tensors.keys())
        save_file(tensors, output_file)
        shard_idx += 1

    return

def main():
    # Build the command-line argument parser
    parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files")
    parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3")
    parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf")
    parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")

    # print all the arguments
    print("All the arguments:")
    print(parser.parse_args())

    # Parse the command-line arguments
    args = parser.parse_args()

    safetensor_path = args.safetensor_path
    gguf_path = args.gguf_path
    output_path = args.output_path

    target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
    write_combined_tensor(target_tensor_map, output_path, gguf_loader)

    return

if __name__ == "__main__":
    main()
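The script is driven by the argparse flags defined in `main()`, all of which default to local model paths; a typical invocation looks like this (the paths are whatever your local copies happen to be):

python merge_tensors/merge_safetensor_gguf.py \
    --safetensor_path /mnt/data/model/DeepSeek-V3 \
    --gguf_path /mnt/data/model/DeepseekV3-q4km-gguf \
    --output_path /mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8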