Add data loader to read special weights for fp8; add special weight processing script

Azure 2025-02-24 11:16:23 +00:00
parent 7b7c6a657d
commit 581a524f65
10 changed files with 481 additions and 26 deletions

View file

@@ -1,3 +1,4 @@
# Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
from typing import Tuple
import torch

View file

@@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
down_type = None
for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
if self.gguf_loader.safetensor_loader is not None:
# temporary (admittedly ugly) workaround to load the tensor from safetensors
gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
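For context, this new branch expects the merged safetensors produced by the weight-processing script (added at the end of this commit) to store each expert weight as the raw GGUF-quantized payload plus a companion .ggml_type tensor. A minimal sketch of reading such a pair back directly; the shard name and layer index are hypothetical:

from safetensors import safe_open

# Shard name and layer index are hypothetical; merged shards follow the
# model-XXXXX-of-XXXXX.safetensors pattern and use GGUF-style keys.
with safe_open("model-00004-of-00062.safetensors", framework="pt") as f:
    gate = f.get_tensor("blk.3.ffn_gate_exps.weight")                  # raw GGUF-quantized payload
    gate_type = f.get_tensor("blk.3.ffn_gate_exps.ggml_type").item()   # GGMLQuantizationType id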

View file

@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):
for key in keys:
key = ".".join(key.split(".")[:-1])
if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
if self.gguf_loader.safetensor_loader is not None:
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
weight_type = weight.dtype
e_score_correction_bias_type = e_score_correction_bias.dtype
res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
tensors = self.load_multi(key, targets, device=device)
weight = tensors[".ffn_gate_inp.weight"]
@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = self.orig_module.weight.to(device)
self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
def unload(self):
if self.weight is not None:

View file

@@ -76,7 +76,13 @@ class KLinearBase(ABC):
keys = [self.key]
for key in keys:
if key + ".weight" in self.gguf_loader.tensor_file_map:
if self.gguf_loader.safetensor_loader is not None:
# using safetensor_loader
tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
elif key + ".weight" in self.gguf_loader.tensor_file_map:
if key + ".bias" in self.gguf_loader.tensor_file_map:
tensors = self.load_multi(key, ["weight", "bias"], device=device)
tensor = tensors["weight"]
@@ -166,6 +172,8 @@ class KLinearTorch(KLinearBase):
self.bias = None
class KLinearFP8(KLinearBase):
# this kernel requires special handling for weight
# Please load the weight file downloaded from KVCache.AI
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
g_idx: torch.Tensor
@@ -191,26 +199,20 @@ class KLinearFP8(KLinearBase):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
orig_shape = list(x.shape)
orig_dtype = x.dtype
x = x.reshape(-1, orig_shape[-1])
x_quantized, scale_x = act_quant(x, self.block_size)
y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
if self.bias is not None:
y += self.bias
return y.to(orig_dtype).reshape(orig_shape)
y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
return y.to(dtype=orig_dtype)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None:
w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
self.weight = w.to(device)
self.has_bias = False
elif isinstance(w, tuple):
### TODO: fit the weight_scale_inv format
if isinstance(w, tuple):
self.weight = w[0].to(device)
self.bias = w[1].to(device)
self.has_bias = True
self.weight_scale_inv = w[1].to(device)
self.has_bias = False
else:
raise ValueError("Invalid weight type")
self.weight = self.weight.to(device)
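KLinearFP8 keeps the weight in fp8 alongside a per-block weight_scale_inv, quantizes activations on the fly with act_quant, and multiplies them with fp8_gemm. Below is a rough pure-PyTorch sketch of what the blockwise dequantization amounts to, assuming 128x128 blocks and that weight_scale_inv is the per-block multiplier (as in the DeepSeek-V3 reference kernel these Triton ops are adapted from); it is an illustration, not the runtime code path:

import torch

def dequant_blockwise_sketch(w_fp8: torch.Tensor, scale_inv: torch.Tensor, block: int = 128) -> torch.Tensor:
    # w_fp8: (N, K) fp8 weight; scale_inv: (ceil(N/block), ceil(K/block)) per-block multipliers.
    w = w_fp8.to(torch.float32)
    n, k = w.shape
    for bi in range(scale_inv.shape[0]):
        for bj in range(scale_inv.shape[1]):
            r0, r1 = bi * block, min((bi + 1) * block, n)
            c0, c1 = bj * block, min((bj + 1) * block, k)
            w[r0:r1, c0:c1] *= scale_inv[bi, bj]   # scale each block back to full range
    return w.to(torch.bfloat16)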
@@ -425,7 +427,8 @@ class KLinearCPUInfer(KLinearBase):
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
"KLinearCPUInfer": KLinearCPUInfer
"KLinearCPUInfer": KLinearCPUInfer,
"KLinearFP8": KLinearFP8,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):
@@ -472,10 +475,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
def forward(self, x):
if self.mode == InferenceState.PREFILL:
assert self.prefill_linear is not None, "cpu linear is not initialized"
return self.prefill_linear.forward(x)
y = self.prefill_linear.forward(x)
else:
assert self.generate_linear is not None, "gpu linear is not initialized"
return self.generate_linear.forward(x)
y = self.generate_linear.forward(x)
return y
def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
if not mode:

View file

@@ -0,0 +1,63 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View file

@@ -3,7 +3,7 @@ import torch.nn.functional as F
from typing import Optional
import pytest
from typing import Tuple, Optional, Literal
import time
# use dir path
import os
import sys
@@ -56,18 +56,61 @@ def test_fp8_gemm_vs_torch_matmul_load():
print(f"weight_dequantized: {weight_dequantized.shape}")
N, K = weight_dequantized.shape
M = 64
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
x_quantized, scale_x = act_quant(x, block_size)
# Test case 1: quantized x matmul with the still-quantized fp8 weight
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
print(f"result_fp8_gemm:\n {result_fp8_gemm}")
print(f"dtype {result_fp8_gemm.dtype}")
# Perform torch.matmul using the original floating point tensors
result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
print(f"result_torch_matmul:\n {result_torch_matmul}")
def test_fp8_gemm_tplops():
file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
with safe_open(file_path, framework="pt", device=0) as f:
weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
# weight_dequant
weight_dequantized = weight_dequant(weight, scale)
print(f"weight_dequantized: {weight_dequantized.shape}")
N, K = weight_dequantized.shape
M = 6400
x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
# x_quantized, scale_x = act_quant(x, block_size)
# Time i iterations of act_quant + fp8_gemm and report achieved TFLOPS
i = 10
flops_per_gemm = 2 * 2 * M * N * K # x is (2, M, K): 2*M*N outputs, K MACs each, 2 flops per MAC
total_flops = i * flops_per_gemm
x_quantized, scale_x = act_quant(x, block_size)
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
x_quantized, scale_x = act_quant(x, block_size)
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(i):
x_quantized, scale_x = act_quant(x, block_size)
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
torch.cuda.synchronize()
t1 = time.time()
total_time = t1 - t0
tflops = total_flops / total_time / 1e12
print(f"total_time: {total_time}")
print(f"tflops: {tflops}")
if __name__ == "__main__":
test_fp8_gemm_vs_torch_matmul()
test_fp8_gemm_vs_torch_matmul_load()
test_fp8_gemm_tplops()

View file

@@ -25,6 +25,7 @@ import os
from enum import IntEnum
import torch
import KTransformersOps
from .custom_loader import SafeTensorLoader
class GGMLQuantizationType(IntEnum):
F32 = 0
@@ -168,6 +169,7 @@ class GGUFLoader:
gguf_path: str
tensor_file_map: dict # {tensor_name: tensor_file_path}
gguf_file_meta: dict
safetensor_loader: SafeTensorLoader
def __init__(self, gguf_path: str):
# Check dir exist
if not os.path.exists(gguf_path):
@@ -175,6 +177,8 @@
if os.path.isfile(gguf_path):
gguf_path = os.path.dirname(gguf_path)
self.safetensor_loader = None
self.tensor_info = {}
self.gguf_path = gguf_path
self.tensor_file_map = {}
@@ -182,6 +186,12 @@
self.gguf_file_meta = {}
self.tensor_device_map = {}
# I know this is ugly, but I don't want to change the original code too much
# TODO: merge gguf load and other loads.
safetensor_loader = SafeTensorLoader(gguf_path)
if safetensor_loader.tensor_file_map:
self.safetensor_loader = safetensor_loader
return
# Walk through all the .gguf files in the directory
found_gguf = False
for root, dirs, files in os.walk(gguf_path):
@@ -288,6 +298,13 @@
itemsize = int(np.empty([], dtype = item_type).itemsize)
return mmap_data[offset : offset + itemsize * item_count]
def get_undequanted_tensor_and_ggml_type(self, name):
t = self.tensor_info[name]
data = self.get_mmap_tensor(name)
ggml_type = t["ggml_type"]
data = torch.from_numpy(data)
return data, ggml_type
def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
t = self.tensor_info[name]
if device.lower() == "cpu":

View file

@@ -0,0 +1,86 @@
import struct
import warnings
import numpy as np
import re
import numpy.typing as npt
from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors.torch import save_file
class SafeTensorLoader:
tensor_file_map = {}
tensor_type_map = {}
file_handle_map = {}
def __init__(self, file_path: str):
self.__load_tensor_file_map(file_path)
def __load_tensor_file_map(self, file_path: str):
# Normalize the given path: make sure we end up with a folder path
if not os.path.exists(file_path):
raise FileNotFoundError(f"Path not found: {file_path}")
if os.path.isfile(file_path):
folder_path = os.path.dirname(file_path)
else:
folder_path = file_path
found_safetensor = False
for root, _, files in os.walk(folder_path):
files = sorted(files)
for file in files:
if file.endswith(".safetensors"):
found_safetensor = True
file_path = os.path.join(root, file)
if file not in self.file_handle_map:
try:
handle = safe_open(file_path, framework="pt")
self.file_handle_map[file] = handle
except Exception as e:
print(f"Error opening Safetensor file {file_path}: {e}")
continue
f = self.file_handle_map.get(file)
if f is None:
continue
try:
for key in f.keys():
self.tensor_file_map[key] = file
except Exception as e:
print(f"Error reading Safetensor file {file_path}: {e}")
# if not found_safetensor:
# raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
def load_tensor(self, key: str, device: str="cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
return tensor.to(device)
def close_all_handles(self):
for handle in self.file_handle_map.values():
handle.close()
self.file_handle_map.clear()
def load_dequantized_tensor(self, key:str, device: str="cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key).to(device)
if key.endswith(".weight"):
if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
tensor = weight_dequant(tensor, weight_scale_inv)
return tensor.to(device)
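A short usage sketch for the new loader (directory and key are illustrative): load_tensor returns the stored tensor as-is, while load_dequantized_tensor additionally applies weight_dequant when a matching .weight_scale_inv entry exists.

from ktransformers.util.custom_loader import SafeTensorLoader

loader = SafeTensorLoader("/mnt/data/model/DeepSeek-V3")   # illustrative directory
key = "model.layers.0.mlp.down_proj.weight"                # key taken from the fp8gemm test above
w_fp8 = loader.load_tensor(key)                            # stored fp8 weight, unchanged
w = loader.load_dequantized_tensor(key, device="cuda")     # fp8 * weight_scale_inv via weight_dequant
loader.close_all_handles()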

View file

@@ -66,12 +66,23 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
for name, param in local_state.items():
key = prefix + name
translated_key = translate_name_to_gguf(key)
if translated_key in gguf_loader.tensor_file_map:
# TODO: merge all loaders.
# I know this is ugly, but let's do it for now.
if gguf_loader.safetensor_loader is not None:
load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
else:
load_dequantized_tensor = gguf_loader.load_gguf_tensor
tensor_file_map = gguf_loader.tensor_file_map
if translated_key in tensor_file_map:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
# weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights)
del weights
else:

View file

@@ -0,0 +1,214 @@
# This script merges the fp8 safetensor weights with the GGUF-quantized tensors.
import os
# insert the project root into sys.path
import sys
sys.path.insert(0, "/home/azure/ktransformers")
import argparse
import torch
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
from safetensors import safe_open
from safetensors.torch import save_file
import re
from collections import defaultdict
def read_safetensor_keys_from_folder(folder_path)->dict:
"""
:param folder_path: folder path
:return: key_to_file_map
"""
# Check that the folder path exists
if not os.path.exists(folder_path):
raise FileNotFoundError(f"Safetensor dir not found: {folder_path}")
if os.path.isfile(folder_path):
folder_path = os.path.dirname(folder_path)
key_to_file_map = {}
found_safetensor = False
for root, dirs, files in os.walk(folder_path):
# sort files
files = sorted(files)
for file in files:
if file.endswith(".safetensors"):
found_safetensor = True
file_path = os.path.join(root, file)
try:
with safe_open(file_path, framework="pt") as f:
for key in f.keys():
if "model.layers.61" in key:
# skip MTP layer
continue
# try:
# if int(key.split('.')[2]) > 4:
# continue
# except:
# pass
key_to_file_map[key] = file_path
except Exception as e:
print(f"Error reading Safetensor file {file_path}: {e}")
if not found_safetensor:
raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
return key_to_file_map
tensor_from_gguf = [] # TODO: add keys whose tensors should be taken from the GGUF file instead of the safetensors
def translate_name(name:str)->str:
"""
:param name: name of the tensor
:return: translated name
"""
name = translate_name_to_gguf(name)
name = name.replace(".up_proj.", ".ffn_up_exps.")
name = name.replace(".down_proj.", ".ffn_down_exps.")
name = name.replace(".gate_proj.", ".ffn_gate_exps.")
name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias")
return name
def combine_tensor_sources(safetensor_path:str, gguf_path:str):
gguf_loader = GGUFLoader(gguf_path)
gguf_tensor_file_map = gguf_loader.tensor_file_map
safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)
# Build a map from each target key to the file that provides its tensor
target_tensor_map = {}
for key in safetensor_tensor_file_map.keys():
# for all experts, we use the gguf tensor
if ".mlp.experts." in key:
if '.weight_scale_inv' in key:
continue
key = '.'.join(key.split('.')[:5]+key.split('.')[-2:])
translated_key = translate_name(key)
target_tensor_map[key] = gguf_tensor_file_map[translated_key]
continue
if any(target_key in key for target_key in tensor_from_gguf):
target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]
else:
target_tensor_map[key] = safetensor_tensor_file_map[key]
return target_tensor_map, gguf_loader
def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
# Ensure output directory exists
os.makedirs(output_path, exist_ok=True)
# Cache for safetensor file handles and GGUF loaders
safetensors_cache = {}
gguf_cache = {}
# Group tensors by layer
layer_groups = defaultdict(list)
non_layer_keys = []
layer_pattern = re.compile(r'\.layers\.(\d+)\.')
for key in target_tensor_map:
match = layer_pattern.search(key)
if match:
layer_num = int(match.group(1))
layer_groups[layer_num].append(key)
else:
non_layer_keys.append(key)
# Calculate total shards
total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1
if total_shards == 0:
raise ValueError("No tensors to save")
shard_idx = 0
# Save non-layer tensors to the first shard if they exist
if non_layer_keys:
tensors = {}
for key in non_layer_keys:
file_path = target_tensor_map[key]
tensor = None
ggml_type = None
if file_path.endswith('.safetensors'):
if file_path not in safetensors_cache:
safetensors_cache[file_path] = safe_open(file_path, framework='pt')
f = safetensors_cache[file_path]
tensor = f.get_tensor(key)
elif file_path.endswith('.gguf'):
gguf_name = translate_name(key)
tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
else:
raise ValueError(f"Unsupported file format: {file_path}")
tensors[translate_name(key)] = tensor
if ggml_type:
ggml_type = torch.tensor(ggml_type)
ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
tensors[ggml_key] = ggml_type
output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
print(f"Saving non-layer tensors to {output_file}")
save_file(tensors, output_file)
print(tensors.keys())
shard_idx += 1
# Save each layer's tensors to subsequent shards
for layer_num in sorted(layer_groups.keys()):
layer_keys = layer_groups[layer_num]
tensors = {}
for key in layer_keys:
file_path = target_tensor_map[key]
tensor = None
ggml_type = None
if file_path.endswith('.safetensors'):
if file_path not in safetensors_cache:
safetensors_cache[file_path] = safe_open(file_path, framework='pt')
f = safetensors_cache[file_path]
tensor = f.get_tensor(key)
tensor_info = tensor.shape
elif file_path.endswith('.gguf'):
gguf_name = translate_name(key)
tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
# tensor_info = gguf_loader.tensor_info[gguf_name]
# ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type']
else:
raise ValueError(f"Unsupported file format: {file_path}")
tensors[translate_name(key)] = tensor
if ggml_type:
ggml_type = torch.tensor(ggml_type)
ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
tensors[ggml_key] = ggml_type
output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
print(f"Saving layer {layer_num} to {output_file}")
print(tensors.keys())
save_file(tensors, output_file)
shard_idx += 1
return
def main():
# Create the command-line argument parser
parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files")
parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3")
parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf")
parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")
# print all the arguments
print("All the arguments:")
print(parser.parse_args())
# Parse the command-line arguments
args = parser.parse_args()
safetensor_path = args.safetensor_path
gguf_path = args.gguf_path
output_path = args.output_path
target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
write_combined_tensor(target_tensor_map, output_path, gguf_loader)
return
if __name__ == "__main__":
main()
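For reference, a hedged sketch of how the merge flow is driven; the script's filename is not shown in this diff, so the command below uses a placeholder name, and the paths are just the argparse defaults.

# Assumed invocation (script name is a placeholder):
#   python merge_fp8_safetensor_gguf.py \
#       --safetensor_path /mnt/data/model/DeepSeek-V3 \
#       --gguf_path /mnt/data/model/DeepseekV3-q4km-gguf \
#       --output_path /mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8
# Equivalent direct calls, with the functions above in scope:
target_tensor_map, gguf_loader = combine_tensor_sources(
    "/mnt/data/model/DeepSeek-V3",
    "/mnt/data/model/DeepseekV3-q4km-gguf",
)
write_combined_tensor(
    target_tensor_map,
    "/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8",
    gguf_loader,
)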