diff --git a/ktransformers/ktransformers_ext/triton/fp8gemm.py b/ktransformers/ktransformers_ext/triton/fp8gemm.py
index 4da4cfe..7d5b72e 100644
--- a/ktransformers/ktransformers_ext/triton/fp8gemm.py
+++ b/ktransformers/ktransformers_ext/triton/fp8gemm.py
@@ -1,3 +1,4 @@
+# Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
 from typing import Tuple
 
 import torch
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 21b4830..1ea244a 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
         down_type = None
 
         for key in keys:
-            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+            if self.gguf_loader.safetensor_loader is not None:
+                # temporary workaround: load the tensors directly through the safetensor loader
+                gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
+                up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
+                down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
+                gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
+                up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
+                down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
+
+            elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
                 gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
                 up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
                 down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py
index 52bb33a..d908093 100644
--- a/ktransformers/operators/gate.py
+++ b/ktransformers/operators/gate.py
@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):
 
         for key in keys:
             key = ".".join(key.split(".")[:-1])
-            if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
+            if self.gguf_loader.safetensor_loader is not None:
+                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
+                weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
+                e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
+                weight_type = weight.dtype
+                e_score_correction_bias_type = e_score_correction_bias.dtype
+                res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
+            elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
                 targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                 tensors = self.load_multi(key, targets, device=device)
                 weight = tensors[".ffn_gate_inp.weight"]
@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
             self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
         else:
             raise ValueError("Invalid weight type")
-        self.orig_module.weight = self.orig_module.weight.to(device)
-        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
+        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
 
     def unload(self):
         if self.weight is not None:
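A note on the `nn.Parameter` re-wrapping in `KMoEGate.load` above: when `.to(device)` actually converts or moves a registered parameter, it returns a plain `torch.Tensor`, which a module will not accept back into a parameter slot. A minimal, self-contained illustration (CPU-only, using a dtype change to stand in for a device move):

```python
import torch
import torch.nn as nn

m = nn.Linear(4, 4)

# A .to() call that really converts returns a plain Tensor, not an nn.Parameter.
moved = m.weight.to(torch.float16)
print(type(m.weight).__name__, "->", type(moved).__name__)  # Parameter -> Tensor

# Assigning the raw Tensor back to a registered parameter slot raises TypeError.
try:
    m.weight = moved
except TypeError as e:
    print("direct assignment fails:", e)

# Re-wrapping keeps the attribute registered as a parameter, as the patch does.
m.weight = nn.Parameter(moved)
print(sum(p.numel() for p in m.parameters()))  # weight and bias still counted
```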
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 5aff964..e778102 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -76,7 +76,13 @@ class KLinearBase(ABC):
         keys = [self.key]
 
         for key in keys:
-            if key + ".weight" in self.gguf_loader.tensor_file_map:
+            if self.gguf_loader.safetensor_loader is not None:
+                # using safetensor_loader
+                tensor = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight')
+                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight_scale_inv')
+                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+
+            elif key + ".weight" in self.gguf_loader.tensor_file_map:
                 if key + ".bias" in self.gguf_loader.tensor_file_map:
                     tensors = self.load_multi(key, ["weight", "bias"], device=device)
                     tensor = tensors["weight"]
@@ -166,6 +172,8 @@ class KLinearTorch(KLinearBase):
         self.bias = None
 
 class KLinearFP8(KLinearBase):
+    # This kernel requires special handling of its weights.
+    # Please load the weight files downloaded from KVCache.AI.
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
    g_idx: torch.Tensor
@@ -191,26 +199,20 @@ class KLinearFP8(KLinearBase):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.to(self.device)
-        orig_shape = list(x.shape)
-        orig_dtype = x.dtype
-        x = x.reshape(-1, orig_shape[-1])
+        orig_dtype = x.dtype
         x_quantized, scale_x = act_quant(x, self.block_size)
-        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
-        if self.bias is not None:
-            y += self.bias
-        return y.to(orig_dtype).reshape(orig_shape)
+        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
+        return y.to(dtype=orig_dtype)
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
         if device is None: device = self.device
         if w is None:
             w = self.load_weight(device=device)
-        if isinstance(w, nn.Parameter):
-            self.weight = w.to(device)
-            self.has_bias = False
-        elif isinstance(w, tuple):
+        ### TODO: fit the weight_scale_inv format
+        if isinstance(w, tuple):
             self.weight = w[0].to(device)
-            self.bias = w[1].to(device)
-            self.has_bias = True
+            self.weight_scale_inv = w[1].to(device)
+            self.has_bias = False
         else:
             raise ValueError("Invalid weight type")
         self.weight = self.weight.to(device)
@@ -425,7 +427,8 @@ class KLinearCPUInfer(KLinearBase):
 LINEAR_MAP = {
     "KLinearMarlin": KLinearMarlin,
     "KLinearTorch": KLinearTorch,
-    "KLinearCPUInfer": KLinearCPUInfer
+    "KLinearCPUInfer": KLinearCPUInfer,
+    "KLinearFP8": KLinearFP8,
 }
 
 class KTransformersLinear(BaseInjectedModule, KLinearBase):
@@ -472,10 +475,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
     def forward(self, x):
         if self.mode == InferenceState.PREFILL:
             assert self.prefill_linear is not None, "cpu linear is not initialized"
-            return self.prefill_linear.forward(x)
+            y = self.prefill_linear.forward(x)
         else:
             assert self.generate_linear is not None, "gpu linear is not initialized"
-            return self.generate_linear.forward(x)
+            y = self.generate_linear.forward(x)
+        return y
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
         if not mode:
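To make the `weight_scale_inv` handling in `KLinearFP8` concrete: the FP8 checkpoint stores each weight as an FP8 tensor plus a grid of per-block scales, and `fp8_gemm(x_q, s_x, W, W_scale_inv)` is numerically equivalent to dequantizing the weight block-wise and running an ordinary matmul. The sketch below emulates that semantics in plain PyTorch; the 128x128 block layout and the use of float32 as a stand-in for the FP8 storage dtype are assumptions for illustration, and `reference_fp8_linear` is not part of the repository.

```python
import torch

def blockwise_dequant(w_q: torch.Tensor, scale_inv: torch.Tensor, block: int = 128) -> torch.Tensor:
    """Reference dequantization: scale_inv holds one scale per (block x block) tile of w_q."""
    w = w_q.to(torch.float32)
    s = scale_inv.repeat_interleave(block, dim=0)[: w.shape[0]]
    s = s.repeat_interleave(block, dim=1)[:, : w.shape[1]]
    return w * s

def reference_fp8_linear(x: torch.Tensor, w_q: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Roughly what KLinearFP8.forward computes, up to quantization error: y = x @ dequant(W).T"""
    w = blockwise_dequant(w_q, scale_inv)
    return (x.to(torch.float32) @ w.T).to(x.dtype)

if __name__ == "__main__":
    N, K = 256, 384                                       # output features, input features
    w_q = torch.randn(N, K)                               # stands in for the stored FP8 weight
    scale_inv = torch.rand((N + 127) // 128, (K + 127) // 128)  # hypothetical per-block scales
    x = torch.randn(2, 8, K, dtype=torch.bfloat16)
    y = reference_fp8_linear(x, w_q, scale_inv)
    print(y.shape, y.dtype)                               # torch.Size([2, 8, 256]) torch.bfloat16
```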
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml
new file mode 100644
index 0000000..25f021e
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml
@@ -0,0 +1,63 @@
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching both name and class
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearFP8"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "^model\\.layers\\..*\\.mlp$"
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
+  replace:
+    class: ktransformers.operators.gate.KMoEGate
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
+    kwargs:
+      prefill_device: "cuda"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda"
+  recursive: False  # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\..*\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
\ No newline at end of file
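The `name` fields in the rules above are ordinary Python regular expressions matched against module paths; in particular, the negative lookahead in the linear rule is what keeps `kv_b_proj` on the original implementation. A quick sanity check (the module names below are hypothetical examples, not taken from the repository):

```python
import re

linear_rule = r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$"
experts_rule = r"^model\.layers\..*\.mlp\.experts$"

names = [
    "model.layers.0.self_attn.q_a_proj",   # matched -> replaced by KTransformersLinear
    "model.layers.0.self_attn.kv_b_proj",  # excluded by the negative lookahead
    "model.layers.3.mlp.experts",          # also matched by the experts rule
    "model.embed_tokens",                  # not a layer module
]

for n in names:
    print(f"{n:40s} linear={bool(re.match(linear_rule, n))} experts={bool(re.match(experts_rule, n))}")
```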
diff --git a/ktransformers/tests/triton_fp8gemm_test.py b/ktransformers/tests/triton_fp8gemm_test.py
index bb3801c..58888d6 100644
--- a/ktransformers/tests/triton_fp8gemm_test.py
+++ b/ktransformers/tests/triton_fp8gemm_test.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 from typing import Optional
 import pytest
 from typing import Tuple, Optional, Literal
-
+import time
 # use dir path
 import os
 import sys
@@ -56,18 +56,61 @@ def test_fp8_gemm_vs_torch_matmul_load():
     print(f"weight_dequantized: {weight_dequantized.shape}")
     N, K = weight_dequantized.shape
     M = 64
-    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
     x_quantized, scale_x = act_quant(x, block_size)
 
     # Test case 1: quantized x matmul with the still-quantized weight
     result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
     print(f"result_fp8_gemm:\n {result_fp8_gemm}")
+    print(f"dtype {result_fp8_gemm.dtype}")
 
     # Perform torch.matmul using the original floating point tensors
     result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
     print(f"result_torch_matmul:\n {result_torch_matmul}")
 
+def test_fp8_gemm_tplops():
+    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
+    with safe_open(file_path, framework="pt", device=0) as f:
+        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
+        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
+
+    # weight_dequant
+    weight_dequantized = weight_dequant(weight, scale)
+    print(f"weight_dequantized: {weight_dequantized.shape}")
+    N, K = weight_dequantized.shape
+    M = 6400
+    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
+    # x_quantized, scale_x = act_quant(x, block_size)
+
+    # Time i iterations of act_quant + fp8_gemm
+    i = 10
+    flops_per_gemm = 2 * M * N * K
+    total_flops = i * flops_per_gemm
+
+    x_quantized, scale_x = act_quant(x, block_size)
+    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+    x_quantized, scale_x = act_quant(x, block_size)
+    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for _ in range(i):
+        x_quantized, scale_x = act_quant(x, block_size)
+        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+    torch.cuda.synchronize()
+    t1 = time.time()
+
+    total_time = t1 - t0
+    tflops = total_flops / total_time / 1e12
+    print(f"total_time: {total_time}")
+    print(f"tflops: {tflops}")
+
+
+
 if __name__ == "__main__":
     test_fp8_gemm_vs_torch_matmul()
     test_fp8_gemm_vs_torch_matmul_load()
+    test_fp8_gemm_tplops()
\ No newline at end of file
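The throughput test above times asynchronous CUDA work with `time.time()` bracketed by explicit synchronizations. An equivalent pattern that avoids host-timer noise is to use CUDA events; the sketch below also counts the batch dimension in the FLOP estimate, since `x` is `(B, M, K)` and each iteration therefore performs roughly `2*B*M*N*K` FLOPs. The helper is illustrative and not part of the repository; it requires a CUDA device.

```python
import torch

def benchmark_tflops(fn, flops_per_iter: int, iters: int = 10, warmup: int = 2) -> float:
    """Time fn() with CUDA events and return achieved TFLOPS (illustrative helper)."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1e3   # elapsed_time returns milliseconds
    return iters * flops_per_iter / seconds / 1e12

# Example usage with the kernels exercised by the test above:
#   B, M = 2, 6400; N, K = weight_dequantized.shape
#   tflops = benchmark_tflops(
#       lambda: fp8_gemm(*act_quant(x, block_size), weight, scale),
#       flops_per_iter=2 * B * M * N * K,
#   )
```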
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 26afd39..d054ad3 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -25,6 +25,7 @@ import os
 from enum import IntEnum
 import torch
 import KTransformersOps
+from .custom_loader import SafeTensorLoader
 
 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -168,12 +169,15 @@ class GGUFLoader:
     gguf_path: str
     tensor_file_map: dict # {tensor_name: tensor_file_path}
     gguf_file_meta: dict
+    safetensor_loader: SafeTensorLoader
     def __init__(self, gguf_path: str):
         # Check dir exist
         if not os.path.exists(gguf_path):
             raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
         if os.path.isfile(gguf_path):
             gguf_path = os.path.dirname(gguf_path)
+
+        self.safetensor_loader = None
 
         self.tensor_info = {}
         self.gguf_path = gguf_path
@@ -181,7 +185,13 @@ class GGUFLoader:
         self.file_data_map = {}
         self.gguf_file_meta = {}
         self.tensor_device_map = {}
-
+
+        # I know this is ugly, but I don't want to change the original code too much
+        # TODO: merge the GGUF load path with the other loaders.
+        safetensor_loader = SafeTensorLoader(gguf_path)
+        if safetensor_loader.tensor_file_map:
+            self.safetensor_loader = safetensor_loader
+            return
         # Walk through all the .gguf files in the directory
         found_gguf = False
         for root, dirs, files in os.walk(gguf_path):
@@ -288,6 +298,13 @@ class GGUFLoader:
             itemsize = int(np.empty([], dtype = item_type).itemsize)
             return mmap_data[offset : offset + itemsize * item_count]
 
+    def get_undequanted_tensor_and_ggml_type(self, name):
+        t = self.tensor_info[name]
+        data = self.get_mmap_tensor(name)
+        ggml_type = t["ggml_type"]
+        data = torch.from_numpy(data)
+        return data, ggml_type
+
     def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
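With this change `GGUFLoader` effectively becomes a dispatcher: if the directory it is given contains `.safetensors` files, it only sets up `self.safetensor_loader` and returns early, otherwise it falls back to the original GGUF scan. A hedged usage sketch (the path and the tensor key are placeholders, not real names from a checkpoint):

```python
from ktransformers.util.custom_gguf import GGUFLoader

loader = GGUFLoader("/path/to/merged-or-gguf-weights")   # placeholder path

if loader.safetensor_loader is not None:
    # safetensors-backed checkpoint, e.g. one produced by merge_safetensor_gguf.py
    w = loader.safetensor_loader.load_dequantized_tensor("some.tensor.weight", device="cpu")
else:
    # plain GGUF directory: the original mmap-based path is still used
    w = loader.load_gguf_tensor("some.tensor.weight", device="cpu")
print(w.shape, w.dtype)
```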
diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py
new file mode 100644
index 0000000..ecc09a0
--- /dev/null
+++ b/ktransformers/util/custom_loader.py
@@ -0,0 +1,86 @@
+import struct
+import warnings
+import numpy as np
+import re
+import numpy.typing as npt
+from typing import Sequence
+import os
+from enum import IntEnum
+import torch
+import KTransformersOps
+from safetensors import safe_open
+from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
+from safetensors.torch import save_file
+
+class SafeTensorLoader:
+    tensor_file_map = {}
+    tensor_type_map = {}
+    file_handle_map = {}
+
+    def __init__(self, file_path: str):
+        self.__load_tensor_file_map(file_path)
+
+    def __load_tensor_file_map(self, file_path: str):
+        # Normalize the given path and make sure we end up with a folder path
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"Path not found: {file_path}")
+        if os.path.isfile(file_path):
+            folder_path = os.path.dirname(file_path)
+        else:
+            folder_path = file_path
+
+        found_safetensor = False
+        for root, _, files in os.walk(folder_path):
+            files = sorted(files)
+            for file in files:
+                if file.endswith(".safetensors"):
+                    found_safetensor = True
+                    file_path = os.path.join(root, file)
+                    if file not in self.file_handle_map:
+                        try:
+                            handle = safe_open(file_path, framework="pt")
+                            self.file_handle_map[file] = handle
+                        except Exception as e:
+                            print(f"Error opening Safetensor file {file_path}: {e}")
+                            continue
+
+                    f = self.file_handle_map.get(file)
+                    if f is None:
+                        continue
+                    try:
+                        for key in f.keys():
+                            self.tensor_file_map[key] = file
+                    except Exception as e:
+                        print(f"Error reading Safetensor file {file_path}: {e}")
+
+        # if not found_safetensor:
+        #     raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
+
+    def load_tensor(self, key: str, device: str="cpu"):
+        if key not in self.tensor_file_map:
+            raise KeyError(f"Key {key} not found in Safetensor files")
+        file = self.tensor_file_map[key]
+        f = self.file_handle_map.get(file)
+        if f is None:
+            raise FileNotFoundError(f"File {file} not found in Safetensor files")
+        tensor = f.get_tensor(key)
+        return tensor.to(device)
+
+    def close_all_handles(self):
+        for handle in self.file_handle_map.values():
+            handle.close()
+        self.file_handle_map.clear()
+
+    def load_dequantized_tensor(self, key:str, device: str="cpu"):
+        if key not in self.tensor_file_map:
+            raise KeyError(f"Key {key} not found in Safetensor files")
+        file = self.tensor_file_map[key]
+        f = self.file_handle_map.get(file)
+        if f is None:
+            raise FileNotFoundError(f"File {file} not found in Safetensor files")
+        tensor = f.get_tensor(key).to(device)
+        if key.endswith(".weight"):
+            if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
+                weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
+                tensor = weight_dequant(tensor, weight_scale_inv)
+        return tensor.to(device)
\ No newline at end of file
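`SafeTensorLoader` keeps one open `safe_open` handle per shard plus a flat key-to-file map, so individual tensors can be pulled lazily. A minimal usage sketch, assuming a folder of DeepSeek-V3 FP8 shards at a placeholder path (the dequantized load needs a CUDA device for the Triton `weight_dequant` kernel):

```python
from ktransformers.util.custom_loader import SafeTensorLoader

loader = SafeTensorLoader("/mnt/data/model/DeepSeek-V3")   # placeholder path

# Raw FP8 weight plus its per-block scales, exactly as stored in the shard:
w_fp8 = loader.load_tensor("model.layers.0.mlp.down_proj.weight")
s_inv = loader.load_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")

# Or let the loader apply weight_dequant() when a matching *_scale_inv key exists:
w_deq = loader.load_dequantized_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")

print(w_fp8.dtype, s_inv.shape, w_deq.dtype)
loader.close_all_handles()
```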
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index 81d007c..1c21135 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -66,12 +66,23 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
     for name, param in local_state.items():
         key = prefix + name
         translated_key = translate_name_to_gguf(key)
-        if translated_key in gguf_loader.tensor_file_map:
+
+        # TODO: merge all the loaders.
+        # I know this is ugly, but let's do it for now.
+        if gguf_loader.safetensor_loader is not None:
+            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
+            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
+        else:
+            load_dequantized_tensor = gguf_loader.load_gguf_tensor
+            tensor_file_map = gguf_loader.tensor_file_map
+
+        if translated_key in tensor_file_map:
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
             torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
-            weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
+            # weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
+            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
         else:
diff --git a/merge_tensors/merge_safetensor_gguf.py b/merge_tensors/merge_safetensor_gguf.py
new file mode 100644
index 0000000..7aeb62d
--- /dev/null
+++ b/merge_tensors/merge_safetensor_gguf.py
@@ -0,0 +1,214 @@
+# This script merges the FP8 safetensors weights with the GGUF quantized tensors.
+
+import os
+# insert the path of the project
+import sys
+sys.path.insert(0, "/home/azure/ktransformers")
+import argparse
+import torch
+from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
+from safetensors import safe_open
+from safetensors.torch import save_file
+import re
+from collections import defaultdict
+
+def read_safetensor_keys_from_folder(folder_path)->dict:
+    """
+    :param folder_path: folder path
+    :return: key_to_file_map
+    """
+    # check that the folder path exists
+    if not os.path.exists(folder_path):
+        raise FileNotFoundError(f"GGUF dir not found: {folder_path}")
+    if os.path.isfile(folder_path):
+        folder_path = os.path.dirname(folder_path)
+
+    key_to_file_map = {}
+
+    found_safetensor = False
+    for root, dirs, files in os.walk(folder_path):
+        # sort files
+        files = sorted(files)
+        for file in files:
+            if file.endswith(".safetensors"):
+                found_safetensor = True
+                file_path = os.path.join(root, file)
+                try:
+                    with safe_open(file_path, framework="pt") as f:
+                        for key in f.keys():
+                            if "model.layers.61" in key:
+                                # skip MTP layer
+                                continue
+                            # try:
+                            #     if int(key.split('.')[2]) > 4:
+                            #         continue
+                            # except:
+                            #     pass
+                            key_to_file_map[key] = file_path
+                except Exception as e:
+                    print(f"Error reading Safetensor file {file_path}: {e}")
+
+    if not found_safetensor:
+        raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
+
+    return key_to_file_map
+
+tensor_from_gguf = [] # TODO: add keys in the GGUF that should be used in the final tensor
+
+def translate_name(name:str)->str:
+    """
+    :param name: name of the tensor
+    :return: translated name
+    """
+    name = translate_name_to_gguf(name)
+    name = name.replace(".up_proj.", ".ffn_up_exps.")
+    name = name.replace(".down_proj.", ".ffn_down_exps.")
+    name = name.replace(".gate_proj.", ".ffn_gate_exps.")
+    name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias")
+    return name
+
+
+def combine_tensor_sources(safetensor_path:str, gguf_path:str):
+    gguf_loader = GGUFLoader(gguf_path)
+    gguf_tensor_file_map = gguf_loader.tensor_file_map
+    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)
+
+    # build a map from each key to the file that provides its tensor
+
+    target_tensor_map = {}
+    for key in safetensor_tensor_file_map.keys():
+        # for all experts, we use the gguf tensor
+        if ".mlp.experts." in key:
+            if '.weight_scale_inv' in key:
+                continue
+            key = '.'.join(key.split('.')[:5]+key.split('.')[-2:])
+            translated_key = translate_name(key)
+            target_tensor_map[key] = gguf_tensor_file_map[translated_key]
+            continue
+
+        if any(target_key in key for target_key in tensor_from_gguf):
+            target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]
+        else:
+            target_tensor_map[key] = safetensor_tensor_file_map[key]
+
+    return target_tensor_map, gguf_loader
+
f"model-{shard_idx:05}-of-{total_shards:05}.safetensors") + print(f"Saving layer {layer_num} to {output_file}") + print(tensors.keys()) + save_file(tensors, output_file) + shard_idx += 1 + + return + +def main(): + # 创建命令行参数解析器 + parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files") + parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3") + parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf") + parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8") + + # print all the arguments + print("All the arguments:") + print(parser.parse_args()) + + # 解析命令行参数 + args = parser.parse_args() + + safetensor_path = args.safetensor_path + gguf_path = args.gguf_path + output_path = args.output_path + + target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path) + write_combined_tensor(target_tensor_map, output_path, gguf_loader) + + return + +if __name__ == "__main__": + main() \ No newline at end of file