Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 13:55:27 +00:00)
Add data loader to read special weights for fp8; Add special weight process script

This commit is contained in:
parent 7b7c6a657d
commit 581a524f65

10 changed files with 481 additions and 26 deletions
@@ -1,3 +1,4 @@
# Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
from typing import Tuple

import torch
@@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
        down_type = None

        for key in keys:
            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
            if self.gguf_loader.safetensor_loader is not None:
                # using a temporary (admittedly ugly) way to load the tensor
                gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
                up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
                down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
                gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
                up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
                down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()

            elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
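For context, a minimal sketch of what this safetensor branch expects on disk: expert weights stored under GGUF-style names plus a companion ".ggml_type" tensor written by the merge script added in this commit. The directory path and the "blk.0" key prefix below are placeholders for illustration, not values taken from this diff.

from ktransformers.util.custom_loader import SafeTensorLoader

loader = SafeTensorLoader("/path/to/merged-safetensors")                # dir produced by merge_safetensor_gguf.py
gate = loader.load_tensor("blk.0.ffn_gate_exps.weight").numpy()         # raw, still GGML-quantized data
gate_type = loader.load_tensor("blk.0.ffn_gate_exps.ggml_type").item()  # GGML quantization type id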
@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):

        for key in keys:
            key = ".".join(key.split(".")[:-1])
            if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
            if self.gguf_loader.safetensor_loader is not None:
                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
                e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
                weight_type = weight.dtype
                e_score_correction_bias_type = e_score_correction_bias.dtype
                res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
            elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                tensors = self.load_multi(key, targets, device=device)
                weight = tensors[".ffn_gate_inp.weight"]
@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = self.orig_module.weight.to(device)
        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))

    def unload(self):
        if self.weight is not None:
@@ -76,7 +76,13 @@ class KLinearBase(ABC):
        keys = [self.key]

        for key in keys:
            if key + ".weight" in self.gguf_loader.tensor_file_map:
            if self.gguf_loader.safetensor_loader is not None:
                # using safetensor_loader
                tensor = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight')
                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight_scale_inv')
                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)

            elif key + ".weight" in self.gguf_loader.tensor_file_map:
                if key + ".bias" in self.gguf_loader.tensor_file_map:
                    tensors = self.load_multi(key, ["weight", "bias"], device=device)
                    tensor = tensors["weight"]
@@ -166,6 +172,8 @@ class KLinearTorch(KLinearBase):
        self.bias = None

class KLinearFP8(KLinearBase):
    # this kernel requires special handling for the weight
    # Please load the weight file downloaded from KVCache.AI
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
    g_idx: torch.Tensor
@ -191,26 +199,20 @@ class KLinearFP8(KLinearBase):
|
|||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.to(self.device)
|
||||
orig_shape = list(x.shape)
|
||||
orig_dtype = x.dtype
|
||||
x = x.reshape(-1, orig_shape[-1])
|
||||
x_quantized, scale_x = act_quant(x, self.block_size)
|
||||
y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
|
||||
if self.bias is not None:
|
||||
y += self.bias
|
||||
return y.to(orig_dtype).reshape(orig_shape)
|
||||
y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
|
||||
return y.to(dtype=orig_dtype)
|
||||
|
||||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
|
||||
if device is None: device = self.device
|
||||
if w is None:
|
||||
w = self.load_weight(device=device)
|
||||
if isinstance(w, nn.Parameter):
|
||||
self.weight = w.to(device)
|
||||
self.has_bias = False
|
||||
elif isinstance(w, tuple):
|
||||
### TODO fit weight_inv format
|
||||
if isinstance(w, tuple):
|
||||
self.weight = w[0].to(device)
|
||||
self.bias = w[1].to(device)
|
||||
self.has_bias = True
|
||||
self.weight_scale_inv = w[1].to(device)
|
||||
self.has_bias = False
|
||||
else:
|
||||
raise ValueError("Invalid weight type")
|
||||
self.weight = self.weight.to(device)
|
||||
|
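As a rough illustration of the FP8 path above (act_quant on the activations, then fp8_gemm against the FP8 weight and its weight_scale_inv), mirroring the test later in this diff. The shard path is only an example and the block size of 128 is an assumption, not something stated in this commit.

import torch
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import act_quant, fp8_gemm

block_size = 128  # assumed quantization block size
with safe_open("/path/to/DeepSeek-V3/model-00001-of-000163.safetensors", framework="pt", device=0) as f:
    weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")                      # FP8 weight
    weight_scale_inv = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")  # per-block scales

x = torch.randn(16, weight.shape[1], dtype=torch.bfloat16, device="cuda")
x_q, x_scale = act_quant(x, block_size)               # block-quantize activations to FP8
y = fp8_gemm(x_q, x_scale, weight, weight_scale_inv)  # GEMM with on-the-fly dequantization
print(y.shape)                                        # (16, out_features)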
@@ -425,7 +427,8 @@ class KLinearCPUInfer(KLinearBase):
LINEAR_MAP = {
    "KLinearMarlin": KLinearMarlin,
    "KLinearTorch": KLinearTorch,
    "KLinearCPUInfer": KLinearCPUInfer
    "KLinearCPUInfer": KLinearCPUInfer,
    "KLinearFP8": KLinearFP8,
}

class KTransformersLinear(BaseInjectedModule, KLinearBase):
@@ -472,10 +475,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
    def forward(self, x):
        if self.mode == InferenceState.PREFILL:
            assert self.prefill_linear is not None, "cpu linear is not initialized"
            return self.prefill_linear.forward(x)
            y = self.prefill_linear.forward(x)
        else:
            assert self.generate_linear is not None, "gpu linear is not initialized"
            return self.generate_linear.forward(x)
            y = self.generate_linear.forward(x)
        return y

    def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
        if not mode:
@@ -0,0 +1,63 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with a custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 turns off layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
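The new rule file routes every Linear outside kv_b_proj to KLinearFP8 for decode and KLinearTorch for prefill. A conceptual sketch of how such a match block selects modules follows; this is not the actual ktransformers injection code, it only illustrates that name and class must match simultaneously, as the inline comments in the YAML note.

import re
import torch.nn as nn

def matched_modules(model: nn.Module, name_regex: str, cls: type):
    # Yield modules whose dotted name matches the regex AND whose class matches.
    for name, module in model.named_modules():
        if re.match(name_regex, name) and isinstance(module, cls):
            yield name, module

# e.g. the second rule above: every Linear except the self_attn.kv_b_proj projections
# for name, _ in matched_modules(model, r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$", nn.Linear):
#     print(name)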
@@ -3,7 +3,7 @@ import torch.nn.functional as F
from typing import Optional
import pytest
from typing import Tuple, Optional, Literal

import time
# use dir path
import os
import sys
@@ -56,18 +56,61 @@ def test_fp8_gemm_vs_torch_matmul_load():
    print(f"weight_dequantized: {weight_dequantized.shape}")
    N, K = weight_dequantized.shape
    M = 64
    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
    x_quantized, scale_x = act_quant(x, block_size)

    # Test case 1: quantized x matmul with the undequantized weight
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    print(f"result_fp8_gemm:\n {result_fp8_gemm}")
    print(f"dtype {result_fp8_gemm.dtype}")

    # Perform torch.matmul using the original floating point tensors
    result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
    print(f"result_torch_matmul:\n {result_torch_matmul}")
def test_fp8_gemm_tplops():
    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
    with safe_open(file_path, framework="pt", device=0) as f:
        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")

    # weight_dequant
    weight_dequantized = weight_dequant(weight, scale)
    print(f"weight_dequantized: {weight_dequantized.shape}")
    N, K = weight_dequantized.shape
    M = 6400
    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
    # x_quantized, scale_x = act_quant(x, block_size)

    # Time i iterations of act_quant + fp8_gemm
    i = 10
    flops_per_gemm = 2 * M * N * K
    total_flops = i * flops_per_gemm

    x_quantized, scale_x = act_quant(x, block_size)
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    x_quantized, scale_x = act_quant(x, block_size)
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)

    t0 = time.time()
    torch.cuda.synchronize()
    for i in range(i):
        x_quantized, scale_x = act_quant(x, block_size)
        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    torch.cuda.synchronize()
    t1 = time.time()

    total_time = t1 - t0
    tflops = total_flops / total_time / 1e12
    print(f"total_time: {total_time}")
    print(f"tflops: {tflops}")


if __name__ == "__main__":
    test_fp8_gemm_vs_torch_matmul()
    test_fp8_gemm_vs_torch_matmul_load()
    test_fp8_gemm_tplops()
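For reference, the throughput figure above uses the usual 2*M*N*K count per GEMM (one multiply and one add per accumulated term); note that x carries a leading batch dimension of 2 that this count leaves out. A worked example with plausible down_proj dimensions (the N and K values below are assumptions for illustration; at runtime they come from weight_dequantized.shape):

M, N, K = 6400, 7168, 18432
flops_per_gemm = 2 * M * N * K
print(f"{flops_per_gemm / 1e12:.2f} TFLOPs per GEMM")  # ~1.69 TFLOPs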
@@ -25,6 +25,7 @@ import os
from enum import IntEnum
import torch
import KTransformersOps
from .custom_loader import SafeTensorLoader

class GGMLQuantizationType(IntEnum):
    F32 = 0
@@ -168,6 +169,7 @@ class GGUFLoader:
    gguf_path: str
    tensor_file_map: dict  # {tensor_name: tensor_file_path}
    gguf_file_meta: dict
    safetensor_loader: SafeTensorLoader
    def __init__(self, gguf_path: str):
        # Check that the dir exists
        if not os.path.exists(gguf_path):
@@ -175,6 +177,8 @@ class GGUFLoader:
        if os.path.isfile(gguf_path):
            gguf_path = os.path.dirname(gguf_path)

        self.safetensor_loader = None

        self.tensor_info = {}
        self.gguf_path = gguf_path
        self.tensor_file_map = {}
@@ -182,6 +186,12 @@ class GGUFLoader:
        self.gguf_file_meta = {}
        self.tensor_device_map = {}

        # I know this is ugly, but I don't want to change the original code too much
        # TODO: merge gguf load and other loads.
        safetensor_loader = SafeTensorLoader(gguf_path)
        if safetensor_loader.tensor_file_map:
            self.safetensor_loader = safetensor_loader
            return
        # Walk through all the .gguf files in the directory
        found_gguf = False
        for root, dirs, files in os.walk(gguf_path):
@@ -288,6 +298,13 @@ class GGUFLoader:
        itemsize = int(np.empty([], dtype=item_type).itemsize)
        return mmap_data[offset : offset + itemsize * item_count]

    def get_undequanted_tensor_and_ggml_type(self, name):
        t = self.tensor_info[name]
        data = self.get_mmap_tensor(name)
        ggml_type = t["ggml_type"]
        data = torch.from_numpy(data)
        return data, ggml_type

    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device="gpu") -> torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":
ktransformers/util/custom_loader.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import struct
import warnings
import numpy as np
import re
import numpy.typing as npt
from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors.torch import save_file

class SafeTensorLoader:
    tensor_file_map = {}
    tensor_type_map = {}
    file_handle_map = {}

    def __init__(self, file_path: str):
        self.__load_tensor_file_map(file_path)

    def __load_tensor_file_map(self, file_path: str):
        # Normalize the given path so that it points to a folder
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Path not found: {file_path}")
        if os.path.isfile(file_path):
            folder_path = os.path.dirname(file_path)
        else:
            folder_path = file_path

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            files = sorted(files)
            for file in files:
                if file.endswith(".safetensors"):
                    found_safetensor = True
                    file_path = os.path.join(root, file)
                    if file not in self.file_handle_map:
                        try:
                            handle = safe_open(file_path, framework="pt")
                            self.file_handle_map[file] = handle
                        except Exception as e:
                            print(f"Error opening Safetensor file {file_path}: {e}")
                            continue

                    f = self.file_handle_map.get(file)
                    if f is None:
                        continue
                    try:
                        for key in f.keys():
                            self.tensor_file_map[key] = file
                    except Exception as e:
                        print(f"Error reading Safetensor file {file_path}: {e}")

        # if not found_safetensor:
        #     raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, key: str, device: str = "cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        return tensor.to(device)

    def close_all_handles(self):
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    def load_dequantized_tensor(self, key: str, device: str = "cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key).to(device)
        if key.endswith(".weight"):
            if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
                weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
                tensor = weight_dequant(tensor, weight_scale_inv)
        return tensor.to(device)
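A minimal usage sketch of the loader above; the directory and tensor key are placeholders.

from ktransformers.util.custom_loader import SafeTensorLoader

loader = SafeTensorLoader("/path/to/DeepSeek-V3")
w_fp8 = loader.load_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")
# When a matching <name>.weight_scale_inv tensor exists, this returns the dequantized weight instead.
w = loader.load_dequantized_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")
loader.close_all_handles()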
@@ -66,12 +66,23 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
    for name, param in local_state.items():
        key = prefix + name
        translated_key = translate_name_to_gguf(key)
        if translated_key in gguf_loader.tensor_file_map:

        # TODO: Merge all loaders.
        # I know this is ugly, but let's do it for now.
        if gguf_loader.safetensor_loader is not None:
            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
        else:
            load_dequantized_tensor = gguf_loader.load_gguf_tensor
            tensor_file_map = gguf_loader.tensor_file_map

        if translated_key in tensor_file_map:
            target_dtype = torch.get_default_dtype()
            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
            print(f"loading {translated_key} to {device}")
            torch.cuda.empty_cache()  # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
            weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
            # weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
            set_param(module, name, weights)
            del weights
        else:
merge_tensors/merge_safetensor_gguf.py (new file, 214 lines)
@@ -0,0 +1,214 @@
# This script merges the fp8 safetensor weights with the GGUF-quantized tensors.

import os
# insert the path of the project
import sys
sys.path.insert(0, "/home/azure/ktransformers")
import argparse
import torch
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
from safetensors import safe_open
from safetensors.torch import save_file
import re
from collections import defaultdict

def read_safetensor_keys_from_folder(folder_path) -> dict:
    """
    :param folder_path: folder path
    :return: key_to_file_map
    """
    # check that the folder path exists
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"GGUF dir not found: {folder_path}")
    if os.path.isfile(folder_path):
        folder_path = os.path.dirname(folder_path)

    key_to_file_map = {}

    found_safetensor = False
    for root, dirs, files in os.walk(folder_path):
        # sort files
        files = sorted(files)
        for file in files:
            if file.endswith(".safetensors"):
                found_safetensor = True
                file_path = os.path.join(root, file)
                try:
                    with safe_open(file_path, framework="pt") as f:
                        for key in f.keys():
                            if "model.layers.61" in key:
                                # skip the MTP layer
                                continue
                            # try:
                            #     if int(key.split('.')[2]) > 4:
                            #         continue
                            # except:
                            #     pass
                            key_to_file_map[key] = file_path
                except Exception as e:
                    print(f"Error reading Safetensor file {file_path}: {e}")

    if not found_safetensor:
        raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    return key_to_file_map
tensor_from_gguf = []  # TODO: add keys in gguf that should be used in the final tensor

def translate_name(name: str) -> str:
    """
    :param name: name of the tensor
    :return: translated name
    """
    name = translate_name_to_gguf(name)
    name = name.replace(".up_proj.", ".ffn_up_exps.")
    name = name.replace(".down_proj.", ".ffn_down_exps.")
    name = name.replace(".gate_proj.", ".ffn_gate_exps.")
    name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias")
    return name


def combine_tensor_sources(safetensor_path: str, gguf_path: str):
    gguf_loader = GGUFLoader(gguf_path)
    gguf_tensor_file_map = gguf_loader.tensor_file_map
    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)

    # build a map from each key to the file that holds its tensor;
    # given the key, the tensor can then be read from that file
    target_tensor_map = {}
    for key in safetensor_tensor_file_map.keys():
        # for all experts, we use the gguf tensor
        if ".mlp.experts." in key:
            if '.weight_scale_inv' in key:
                continue
            key = '.'.join(key.split('.')[:5] + key.split('.')[-2:])
            translated_key = translate_name(key)
            target_tensor_map[key] = gguf_tensor_file_map[translated_key]
            continue

        if any(target_key in key for target_key in tensor_from_gguf):
            target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]
        else:
            target_tensor_map[key] = safetensor_tensor_file_map[key]

    return target_tensor_map, gguf_loader
def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)

    # Cache for safetensor file handles and GGUF loaders
    safetensors_cache = {}
    gguf_cache = {}

    # Group tensors by layer
    layer_groups = defaultdict(list)
    non_layer_keys = []
    layer_pattern = re.compile(r'\.layers\.(\d+)\.')

    for key in target_tensor_map:
        match = layer_pattern.search(key)
        if match:
            layer_num = int(match.group(1))
            layer_groups[layer_num].append(key)
        else:
            non_layer_keys.append(key)

    # Calculate total shards
    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1
    if total_shards == 0:
        raise ValueError("No tensors to save")

    shard_idx = 0

    # Save non-layer tensors to the first shard if they exist
    if non_layer_keys:
        tensors = {}
        for key in non_layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
            elif file_path.endswith('.gguf'):
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            tensors[translate_name(key)] = tensor
            if ggml_type:
                ggml_type = torch.tensor(ggml_type)
                ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
                tensors[ggml_key] = ggml_type

        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving non-layer tensors to {output_file}")
        save_file(tensors, output_file)
        print(tensors.keys())

        shard_idx += 1

    # Save each layer's tensors to subsequent shards
    for layer_num in sorted(layer_groups.keys()):
        layer_keys = layer_groups[layer_num]
        tensors = {}
        for key in layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
                tensor_info = tensor.shape
            elif file_path.endswith('.gguf'):
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
                # tensor_info = gguf_loader.tensor_info[gguf_name]
                # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type']
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            tensors[translate_name(key)] = tensor
            if ggml_type:
                ggml_type = torch.tensor(ggml_type)
                ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
                tensors[ggml_key] = ggml_type

        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving layer {layer_num} to {output_file}")
        print(tensors.keys())
        save_file(tensors, output_file)
        shard_idx += 1

    return
def main():
    # Create the command-line argument parser
    parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files")
    parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3")
    parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf")
    parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")

    # print all the arguments
    print("All the arguments:")
    print(parser.parse_args())

    # Parse the command-line arguments
    args = parser.parse_args()

    safetensor_path = args.safetensor_path
    gguf_path = args.gguf_path
    output_path = args.output_path

    target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
    write_combined_tensor(target_tensor_map, output_path, gguf_loader)

    return

if __name__ == "__main__":
    main()
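The same merge can also be driven from Python using the two functions defined above (assuming the repository root is on sys.path); the paths below are simply the script's argparse defaults.

from merge_tensors.merge_safetensor_gguf import combine_tensor_sources, write_combined_tensor

target_tensor_map, gguf_loader = combine_tensor_sources(
    "/mnt/data/model/DeepSeek-V3",           # fp8 safetensor shards
    "/mnt/data/model/DeepseekV3-q4km-gguf",  # GGUF-quantized weights for the experts
)
write_combined_tensor(target_tensor_map, "/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8", gguf_loader)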