update: Qwen3 MoE model adaptation for NPU (framework) (#1706)

Shaoxu Cheng 2025-12-11 17:07:57 +08:00 committed by GitHub
parent 53f6a6d6e1
commit adcfa9080f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 867 additions and 174 deletions

View file

@ -1,21 +0,0 @@
#!/bin/bash
set -e
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# clear build dirs
rm -rf build
rm -rf *.egg-info
rm -rf csrc/build
rm -rf csrc/ktransformers_ext/build
rm -rf csrc/ktransformers_ext/cuda/build
rm -rf csrc/ktransformers_ext/cuda/dist
rm -rf csrc/ktransformers_ext/cuda/*.egg-info
rm -rf ~/.ktransformers
echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
pip install -r ktransformers/server/requirements.txt
echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation
pip install third_party/custom_flashinfer/
echo "Installation completed successfully"

View file

@ -0,0 +1,89 @@
- match:
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbedding
kwargs:
generate_device: "npu"
prefill_device: "npu"
- match:
name: "^lm_head$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2
kwargs:
generate_device: "npu"
prefill_device: "npu"
generate_op: "KLinearTorchW8A8A2"
prefill_op: "KLinearTorchW8A8A2"
- match:
name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate)(?!.*mlp\\.gate)(?!.*mlp\\.experts).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2
kwargs:
generate_device: "npu"
prefill_device: "npu"
generate_op: "KLinearTorchW8A8A2"
prefill_op: "KLinearTorchW8A8A2"
- match:
name: "^model\\.layers\\.(?!.*mlp\\.gate)(?!.*self_attn\\.kv_b_proj)(?!.*mlp\\.experts).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2
kwargs:
generate_device: "npu"
prefill_device: "npu"
generate_op: "KLinearTorchW8A8A2"
prefill_op: "KLinearTorchW8A8A2"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
replace:
class: ktransformers.operators.ascend.ascend_experts.KQwen3MoeSparseMoeBlockW8A8
kwargs:
generate_device: "npu"
prefill_device: "npu"
dump_enable: False
dump_dir: "/mnt/dump_from_mindie/dump_from_kt_moe"
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.ascend.ascend_attention.KQwen3MoeAttentionW8A8A2Serve
kwargs:
generate_device: "npu"
prefill_device: "npu"
absorb_for_prefill: False
dump_enable: False
dump_dir: "/mnt/dump_from_mindie/dump_from_kt_attn"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KQwen2MoeModel"
kwargs:
per_layer_prefill_intput_threshold: 0
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm
replace:
class: ktransformers.operators.ascend.ascend_layernorm.KQwen3MoeRMSNormW8A8
kwargs:
generate_device: "npu"
prefill_device: "npu"
dump_enable: False
dump_dir: "/mnt/dump_from_mindie/dump_from_kt_rms"

View file

@ -1,5 +1,5 @@
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache
from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache, KVC2Qwen3Cache
from transformers import (
AutoTokenizer,
AutoConfig,
@ -39,6 +39,7 @@ except:
use_torch_npu = False
if use_torch_npu:
from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
@ -50,7 +51,7 @@ custom_models = {
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
} # TODO: NPU-specific?
}
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner  # TODO: get_or_create_model_runner is NPU-specific
from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
@ -198,11 +199,15 @@ class Engine:
self.cache = KDeepSeekV3Cache(config, self.args.page_size)
self.model = KDeepseekV2ForCausalLM(config, self.cache)
elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
if config.architectures[0] == "Qwen2MoeForCausalLM":
self.model = KQwen2MoeForCausalLM(config, self.cache)
if not use_torch_npu:
self.cache = KGQACache(config, self.args.page_size)
if config.architectures[0] == "Qwen2MoeForCausalLM":
self.model = KQwen2MoeForCausalLM(config, self.cache)
else:
self.model = KQwen3MoeForCausalLM(config, self.cache)
else:
self.model = KQwen3MoeForCausalLM(config, self.cache)
self.cache = KVC2Qwen3Cache(config, args.max_batch_size, self.args.page_size)
self.model = KNPUQwen3MoeForCausalLM(config, self.cache)
elif config.architectures[0] == "SmallThinkerForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
self.model = KSmallThinkerForCausalLM(config, self.cache)
@ -277,7 +282,11 @@ class Engine:
# self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
#@TODO add config
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM" or config.architectures[0] == "Qwen3NextForCausalLM":
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
if not use_torch_npu:
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
else:
# NPU does not support flash attention
self.model.init_wrapper()
else:
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
@ -322,7 +331,7 @@ class Engine:
batch_size = 0
for i in range(len(self.batch.decode_mini_batches)):
batch_size += len(self.batch.decode_mini_batches[i])
logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
# logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
self.model_runner.run_split(self.batch, self.query_manager)
else:
self.model_runner.run(self.batch, self.query_manager)
@ -403,9 +412,12 @@ def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank
engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
if args.use_cuda_graph:
if 'npu' in engine.device:
print(f"[WARMUP-NPU] start", flush=True)
engine.model_runner.warmup_npu()
else:
engine.model_runner.warmup()
else:
print(f"[WARMUP-NPU] skip warmup, eager mode!", flush=True)
if use_torch_npu:
args.port += torch.distributed.get_rank()
event.set()

View file

@ -175,7 +175,7 @@ class ForwardMiniBatchCombine:
class ForwardMiniBatchSplit:
# In the NPU flow, prefill and decode cannot be merged and must be tracked separately
# In the NPU flow, prefill and decode are packed separately
prefill_batch: int
p_q_len: torch.Tensor # (bsz)
p_kv_len: torch.Tensor # (bsz)
@ -183,99 +183,261 @@ class ForwardMiniBatchSplit:
p_tokens: torch.Tensor # (sum(q_len))
p_temperatures: torch.Tensor # (bsz)
p_top_ps: torch.Tensor # (bsz)
p_block_tables: torch.Tensor # (bsz * maxBlockNum)
p_block_tables: torch.Tensor # (bsz, max_page_num)
p_logits_start: list
decode_batch: int
d_q_len: torch.Tensor
d_kv_len: torch.Tensor
d_position_ids: torch.Tensor
d_tokens: torch.Tensor
d_temperatures: torch.Tensor
d_top_ps: torch.Tensor
d_block_tables: torch.Tensor # (bsz * maxBlockNum)
d_block_tables: torch.Tensor # (bsz, max_page_num)
d_logits_start: list
chunk_size: int
is_last_prefill_chunk: bool
def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo],
prefill_s: list[int] = None, prefill_l: list[int] = None,
device=None, page_size=256, max_page_num=64,
decode_padding_len: int = 1):
def __init__(
self,
prefill_querys_info: list[QueryInfo],
decode_querys_info: list[QueryInfo],
prefill_s: list[int] = None,
prefill_l: list[int] = None,
device=None,
page_size: int = 256,
max_page_num: int = 64,
decode_padding_len: int = 1,
):
# Always use the NPU device
device = torch.device('npu')
batch_decode = len(decode_querys_info)
# batch_prefill = len(prefill_querys_info)
# update valid prefill batch
new_prefill_querys_info = []
for info in prefill_querys_info:
if info is not None:
new_prefill_querys_info.append(info)
batch_prefill = len(new_prefill_querys_info)
self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l)
if prefill_s is None or prefill_l is None:
raise ValueError(
"[ForwardMiniBatchSplit.__init__] prefill_s / prefill_l 不能为空chunk prefill 需要这两个参数"
)
# Filter out None entries
new_prefill_querys_info: list[QueryInfo] = [
info for info in prefill_querys_info if info is not None
]
batch_prefill = len(new_prefill_querys_info)
batch_decode = len(decode_querys_info)
self.prefill_batch = batch_prefill
self.decode_batch = batch_decode
self.batch_size = batch_decode + batch_prefill
self.batch_size = batch_prefill + batch_decode
self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l)
self.chunk_size = prefill_l[0] if prefill_l else 0
self.is_last_prefill_chunk = True
for i, q in enumerate(new_prefill_querys_info):
end_pos = prefill_s[i] + prefill_l[i]
if end_pos < q.query_length:
self.is_last_prefill_chunk = False
break
# ====================== Prefill ======================
self.p_q_len = torch.tensor([], device=device, dtype=torch.int32)
self.p_kv_len = torch.tensor([], device=device, dtype=torch.int32)
self.p_position_ids = torch.tensor([], device=device, dtype=torch.int32)
self.p_block_tables = -1 * torch.ones([self.prefill_batch, max_page_num], device=device, dtype=torch.int32)
# self.p_kv_page_offset = torch.tensor([], device=device, dtype=torch.int32)
self.p_block_tables = -1 * torch.ones(
[self.prefill_batch, max_page_num], device=device, dtype=torch.int32
)
self.p_tokens = torch.tensor([], device=device, dtype=torch.int32)
self.p_temperatures = torch.tensor([], device=device, dtype=torch.float32)
self.p_top_ps = torch.tensor([], device=device, dtype=torch.float32)
self.p_logits_start = []
self.p_logits_start: list[int] = []
for i, prefill_query_info in enumerate(new_prefill_querys_info):
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0
assert prefill_query_info.active_position == 0, '[ERROR] currently do not support prefix cache or chunk prefill in balance serving!'
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
self.p_q_len = torch.concat((self.p_q_len, torch.tensor([prefill_l[i]], device=device, dtype=torch.int32)), dim=0)
self.p_kv_len = torch.concat((self.p_kv_len, torch.tensor([prefill_query_info.active_position + prefill_l[i]], device=device, dtype=torch.int32)), dim=0)
self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[:prefill_kv_block_len]
# self.p_kv_page_offset = torch.concat((self.p_kv_page_offset, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.p_position_ids = torch.concat((self.p_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
self.p_tokens = torch.concat((self.p_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
self.p_logits_start.append(prefill_l[i] - 1 if len(self.p_logits_start) == 0 else sum(prefill_l[:i+1])-1)
qid = getattr(prefill_query_info, "id", -1)
self.p_temperatures = torch.concat((self.p_temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.p_top_ps = torch.concat((self.p_top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
past_len = int(prefill_query_info.active_position)
start = int(prefill_s[i]) # current chunk's start position in query_tokens
chunk_len = int(prefill_l[i])
kv_len = past_len + chunk_len
prefill_kv_block_len = (kv_len + page_size - 1) // page_size
# Q length = current chunk length
self.p_q_len = torch.concat(
(
self.p_q_len,
torch.tensor([chunk_len], device=device, dtype=torch.int32),
),
dim=0,
)
self.p_kv_len = torch.concat(
(
self.p_kv_len,
torch.tensor([kv_len], device=device, dtype=torch.int32),
),
dim=0,
)
self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[
:prefill_kv_block_len
]
self.p_position_ids = torch.concat(
(
self.p_position_ids,
torch.arange(
start,
start + chunk_len,
device=device,
dtype=torch.int32,
),
),
dim=0,
)
self.p_tokens = torch.concat(
(
self.p_tokens,
prefill_query_info.query_tokens[start : start + chunk_len],
),
dim=0,
)
self.p_logits_start.append(
chunk_len - 1
if len(self.p_logits_start) == 0
else sum(prefill_l[: i + 1]) - 1
)
self.p_temperatures = torch.concat(
(
self.p_temperatures,
torch.tensor(
[prefill_query_info.temperature],
device=device,
dtype=torch.float32,
),
),
dim=0,
)
self.p_top_ps = torch.concat(
(
self.p_top_ps,
torch.tensor(
[prefill_query_info.top_p],
device=device,
dtype=torch.float32,
),
),
dim=0,
)
# ====================== Decode ======================
self.d_q_len = torch.tensor([], device=device, dtype=torch.int32)
self.d_kv_len = torch.tensor([], device=device, dtype=torch.int32)
self.d_position_ids = torch.tensor([], device=device, dtype=torch.int32)
self.d_block_tables = -1 * torch.ones([self.decode_batch, max_page_num], device=device, dtype=torch.int32)
# self.p_kv_page_offset = torch.tensor([], device=device, dtype=torch.int32)
self.d_block_tables = -1 * torch.ones(
[self.decode_batch, max_page_num], device=device, dtype=torch.int32
)
self.d_tokens = torch.tensor([], device=device, dtype=torch.int32)
self.d_temperatures = torch.tensor([], device=device, dtype=torch.float32)
self.d_top_ps = torch.tensor([], device=device, dtype=torch.float32)
self.d_logits_start = []
self.d_logits_start: list[int] = []
for i, decode_query_info in enumerate(decode_querys_info):
# print("decode_query_info.active_position is ", decode_query_info.active_position)
qid = getattr(decode_query_info, "id", -1)
past_len = int(decode_query_info.active_position)
decode_kv_block_len = (past_len + decode_padding_len + page_size - 1) // page_size
decode_kv_block_len = (decode_query_info.active_position + decode_padding_len + page_size - 1) // page_size
self.d_q_len = torch.concat((self.d_q_len, torch.tensor([decode_padding_len], device=device, dtype=torch.int32)), dim=0)
self.d_kv_len = torch.concat((self.d_kv_len, torch.tensor([decode_query_info.active_position + decode_padding_len], device=device, dtype=torch.int32)), dim=0)
self.d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[:decode_kv_block_len]
# self.d_kv_page_offset = torch.concat((self.d_kv_page_offset, torch.tensor([(decode_query_info.active_position + decode_padding_len) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.d_position_ids = torch.concat((self.d_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + decode_padding_len, device=device, dtype=torch.int32)), dim=0)
if decode_query_info.active_position > 0:
self.d_tokens = torch.concat((self.d_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+decode_padding_len]), dim=0)
self.d_q_len = torch.concat(
(
self.d_q_len,
torch.tensor(
[decode_padding_len], device=device, dtype=torch.int32
),
),
dim=0,
)
self.d_kv_len = torch.concat(
(
self.d_kv_len,
torch.tensor(
[past_len + decode_padding_len],
device=device,
dtype=torch.int32,
),
),
dim=0,
)
self.d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[
:decode_kv_block_len
]
self.d_position_ids = torch.concat(
(
self.d_position_ids,
torch.arange(
past_len,
past_len + decode_padding_len,
device=device,
dtype=torch.int32,
),
),
dim=0,
)
if past_len > 0:
self.d_tokens = torch.concat(
(
self.d_tokens,
decode_query_info.query_tokens[
past_len : past_len + decode_padding_len
],
),
dim=0,
)
else:
self.d_tokens = torch.concat((self.d_tokens, torch.tensor([0] * decode_padding_len, device=device, dtype=torch.int32)), dim=0)
self.d_logits_start.append(0 if len(self.d_logits_start) == 0 else self.d_logits_start[-1]+decode_padding_len)
# print("self.d_position_ids is ", self.d_position_ids)
self.d_tokens = torch.concat(
(
self.d_tokens,
torch.tensor(
[0] * decode_padding_len,
device=device,
dtype=torch.int32,
),
),
dim=0,
)
self.d_temperatures = torch.concat((self.d_temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.d_top_ps = torch.concat((self.d_top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
self.d_logits_start.append(
0
if len(self.d_logits_start) == 0
else self.d_logits_start[-1] + decode_padding_len
)
self.d_temperatures = torch.concat(
(
self.d_temperatures,
torch.tensor(
[decode_query_info.temperature],
device=device,
dtype=torch.float32,
),
),
dim=0,
)
self.d_top_ps = torch.concat(
(
self.d_top_ps,
torch.tensor(
[decode_query_info.top_p],
device=device,
dtype=torch.float32,
),
),
dim=0,
)
self.p_q_len = self.p_q_len.contiguous()
self.p_kv_len = self.p_kv_len.contiguous()
@ -291,7 +453,6 @@ class ForwardMiniBatchSplit:
self.d_position_ids = self.d_position_ids.reshape(self.decode_batch, -1).contiguous()
self.d_tokens = self.d_tokens.reshape(self.decode_batch, -1).contiguous()
else:
# TODO remove this
self.d_q_len = self.d_q_len.contiguous()
self.d_kv_len = self.d_kv_len.contiguous()
self.d_kv_len_list = self.d_kv_len.flatten().tolist()
@ -301,30 +462,53 @@ class ForwardMiniBatchSplit:
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, decode_padding_len=1, device = None, page_size = 256, max_page_num=64):
device = torch.device('npu')
page_size = 128
batch_decode = len(decode_querys_info)
# batch_prefill = len(prefill_querys_info)
# update valid prefill batch
new_prefill_querys_info = []
for info in prefill_querys_info:
if info is not None:
new_prefill_querys_info.append(info)
batch_prefill = len(new_prefill_querys_info)
self.num_tokens = batch_decode + sum(prefill_l)
def fill(
self,
prefill_querys_info: list[QueryInfo],
decode_querys_info: list[QueryInfo],
prefill_s: list[int] = None,
prefill_l: list[int] = None,
decode_padding_len: int = 1,
device=None,
page_size: int = 256,
max_page_num: int = 64,
):
device = torch.device('npu')
if prefill_s is None or prefill_l is None:
raise ValueError(
"[ForwardMiniBatchSplit.fill] prefill_s / prefill_l 不能为空chunk prefill 需要这两个参数"
)
page_size = 128
new_prefill_querys_info: list[QueryInfo] = [
info for info in prefill_querys_info if info is not None
]
batch_prefill = len(new_prefill_querys_info)
batch_decode = len(decode_querys_info)
self.prefill_batch = batch_prefill
self.decode_batch = batch_decode
self.batch_size = batch_decode + batch_prefill
self.batch_size = batch_prefill + batch_decode
self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l)
self.chunk_size = prefill_l[0] if prefill_l else 0
self.is_last_prefill_chunk = True
for i, q in enumerate(new_prefill_querys_info):
end_pos = prefill_s[i] + prefill_l[i]
if end_pos < q.query_length:
self.is_last_prefill_chunk = False
break
# ---------- Prefill ----------
self.p_q_len = torch.tensor([], device=device, dtype=torch.int32)
self.p_kv_len = torch.tensor([], device=device, dtype=torch.int32)
new_p_position_ids = torch.tensor([], device=device, dtype=torch.int32)
self.p_block_tables = torch.zeros([self.prefill_batch, max_page_num], device=device, dtype=torch.int32)
# self.p_kv_page_offset = torch.tensor([], device=device, dtype=torch.int32)
self.p_block_tables = torch.zeros(
[self.prefill_batch, max_page_num], device=device, dtype=torch.int32
)
new_p_tokens = torch.tensor([], device=device, dtype=torch.int32)
self.p_temperatures = torch.tensor([], device=device, dtype=torch.float32)
@ -332,58 +516,190 @@ class ForwardMiniBatchSplit:
self.p_logits_start = []
for i, prefill_query_info in enumerate(new_prefill_querys_info):
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0
assert prefill_query_info.active_position == 0, '[ERROR] currently do not support prefix cache or chunk prefill in balance serving!'
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
self.p_q_len = torch.concat((self.p_q_len, torch.tensor([prefill_l[i]], device=device, dtype=torch.int32)), dim=0)
self.p_kv_len = torch.concat((self.p_kv_len, torch.tensor([prefill_query_info.active_position + prefill_l[i]], device=device, dtype=torch.int32)), dim=0)
self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[:prefill_kv_block_len]
# self.p_kv_page_offset = torch.concat((self.p_kv_page_offset, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
new_p_position_ids = torch.concat((new_p_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
new_p_tokens = torch.concat((new_p_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
self.p_logits_start.append(prefill_l[i] - 1 if len(self.p_logits_start) == 0 else sum(prefill_l[:i+1])-1)
qid = getattr(prefill_query_info, "id", -1)
past_len = int(prefill_query_info.active_position)
start = int(prefill_s[i])
chunk_len = int(prefill_l[i])
self.p_temperatures = torch.concat((self.p_temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.p_top_ps = torch.concat((self.p_top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
kv_len = past_len + chunk_len
prefill_kv_block_len = (kv_len + page_size - 1) // page_size
self.d_q_len = torch.zeros([1] * self.decode_batch, device=device, dtype=torch.int32)
self.p_q_len = torch.concat(
(
self.p_q_len,
torch.tensor([chunk_len], device=device, dtype=torch.int32),
),
dim=0,
)
self.p_kv_len = torch.concat(
(
self.p_kv_len,
torch.tensor([kv_len], device=device, dtype=torch.int32),
),
dim=0,
)
self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[
:prefill_kv_block_len
]
new_p_position_ids = torch.concat(
(
new_p_position_ids,
torch.arange(
start,
start + chunk_len,
device=device,
dtype=torch.int32,
),
),
dim=0,
)
new_p_tokens = torch.concat(
(
new_p_tokens,
prefill_query_info.query_tokens[start : start + chunk_len],
),
dim=0,
)
self.p_logits_start.append(
chunk_len - 1 if len(self.p_logits_start) == 0 else sum(prefill_l[: i + 1]) - 1
)
self.p_temperatures = torch.concat(
(
self.p_temperatures,
torch.tensor(
[prefill_query_info.temperature],
device=device,
dtype=torch.float32,
),
),
dim=0,
)
self.p_top_ps = torch.concat(
(
self.p_top_ps,
torch.tensor(
[prefill_query_info.top_p],
device=device,
dtype=torch.float32,
),
),
dim=0,
)
if new_p_position_ids.numel() > 0:
self.p_position_ids = new_p_position_ids.contiguous()
if new_p_tokens.numel() > 0:
self.p_tokens = new_p_tokens.contiguous()
# ---------- Decode ----------
self.d_q_len = torch.zeros(
[1] * self.decode_batch, device=device, dtype=torch.int32
)
self.d_kv_len = torch.tensor([], device=device, dtype=torch.int32)
new_d_position_ids = torch.tensor([], device=device, dtype=torch.int32)
new_d_block_tables = -1 * torch.ones([self.decode_batch, max_page_num], device=device, dtype=torch.int32)
# self.p_kv_page_offset = torch.tensor([], device=device, dtype=torch.int32)
new_d_block_tables = -1 * torch.ones(
[self.decode_batch, max_page_num], device=device, dtype=torch.int32
)
new_d_tokens = torch.tensor([], device=device, dtype=torch.int32)
self.d_logits_start = []
self.d_logits_start = []
self.d_temperatures = torch.tensor([], device=device, dtype=torch.float32)
self.d_top_ps = torch.tensor([], device=device, dtype=torch.float32)
for i, decode_query_info in enumerate(decode_querys_info):
decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size
self.d_kv_len = torch.concat((self.d_kv_len, torch.tensor([decode_query_info.active_position + 1], device=device, dtype=torch.int32)), dim=0)
# print("fill self.d_block_tables is ", self.d_block_tables)
new_d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[:decode_kv_block_len]
# print("decode_query_info.block_index[:decode_kv_block_len] is ", decode_query_info.block_index[:decode_kv_block_len])
# self.d_kv_page_offset = torch.concat((self.d_kv_page_offset, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
new_d_position_ids = torch.concat((new_d_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)
# print("decode_query_info.active_position is ", decode_query_info.active_position)
qid = getattr(decode_query_info, "id", -1)
past_len = int(decode_query_info.active_position)
decode_kv_block_len = (past_len + decode_padding_len + page_size - 1) // page_size
if decode_query_info.active_position > 0:
new_d_tokens = torch.concat((new_d_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)
self.d_kv_len = torch.concat(
(
self.d_kv_len,
torch.tensor(
[past_len + decode_padding_len],
device=device,
dtype=torch.int32,
),
),
dim=0,
)
new_d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[
:decode_kv_block_len
]
new_d_position_ids = torch.concat(
(
new_d_position_ids,
torch.arange(
past_len,
past_len + decode_padding_len,
device=device,
dtype=torch.int32,
),
),
dim=0,
)
if past_len > 0:
new_d_tokens = torch.concat(
(
new_d_tokens,
decode_query_info.query_tokens[
past_len : past_len + decode_padding_len
],
),
dim=0,
)
else:
new_d_tokens = torch.concat((new_d_tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)
self.d_logits_start.append(0 if len(self.d_logits_start) == 0 else self.d_logits_start[-1]+1)
new_d_tokens = torch.concat(
(
new_d_tokens,
torch.tensor(
[0] * decode_padding_len,
device=device,
dtype=torch.int32,
),
),
dim=0,
)
self.d_temperatures = torch.concat((self.d_temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.d_top_ps = torch.concat((self.d_top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
self.d_logits_start.append(
0
if len(self.d_logits_start) == 0
else self.d_logits_start[-1] + decode_padding_len
)
self.d_temperatures = torch.concat(
(
self.d_temperatures,
torch.tensor(
[decode_query_info.temperature],
device=device,
dtype=torch.float32,
),
),
dim=0,
)
self.d_top_ps = torch.concat(
(
self.d_top_ps,
torch.tensor(
[decode_query_info.top_p],
device=device,
dtype=torch.float32,
),
),
dim=0,
)
if len(decode_querys_info) > 1:
self.d_position_ids[i].copy_(new_d_position_ids[i])
# self.d_position_ids[i:].zero_()
self.d_tokens[i].copy_(new_d_tokens[i])
self.d_block_tables[i].copy_(new_d_block_tables[i])
else:
self.d_position_ids[:new_d_position_ids.size(0)].copy_(new_d_position_ids)
# self.d_position_ids[new_d_position_ids.size(0):].zero_()
self.d_tokens[:new_d_tokens.size(0)].copy_(new_d_tokens)
self.d_block_tables[0].copy_(new_d_block_tables[0])
@ -391,15 +707,10 @@ class ForwardMiniBatchSplit:
self.p_q_len = self.p_q_len.contiguous()
self.p_kv_len = self.p_kv_len.contiguous()
self.p_block_tables = self.p_block_tables.contiguous()
# self.p_position_ids = self.p_position_ids.contiguous()
# self.p_tokens = self.p_tokens.contiguous()
self.d_q_len = self.d_q_len.contiguous()
self.d_kv_len = self.d_kv_len.contiguous()
self.d_kv_len_list = self.d_kv_len.flatten().tolist()
# self.d_block_tables = self.d_block_tables.contiguous()
# self.d_position_ids = self.d_position_ids.contiguous()
# self.d_tokens = self.d_tokens.contiguous()
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
@ -407,12 +718,13 @@ class ForwardMiniBatchSplit:
def __str__(self):
ret = ''
ret += f'=======Prefill forward info:\n'
ret += f'batch: {self.prefill_batch=}, qLen: {self.p_q_len=}, kvLen: {self.p_kv_len=}\n'
ret += f'tokens: {self.p_tokens=}, posIdx: {self.p_position_ids=}, block_tables: {self.p_block_tables=}\n'
ret += f'=======Decode forward info:\n'
ret += f'batch: {self.decode_batch=}, qLen: {self.d_q_len=}, kvLen: {self.d_kv_len=}\n'
ret += f'tokens: {self.d_tokens=}, posIdx: {self.d_position_ids=}, block_tables: {self.d_block_tables=}\n'
ret += '=======Prefill forward info:\n'
ret += f'batch: {self.prefill_batch}, qLen: {self.p_q_len}, kvLen: {self.p_kv_len}\n'
ret += f'tokens: {self.p_tokens}, posIdx: {self.p_position_ids}, block_tables: {self.p_block_tables}\n'
ret += '=======Decode forward info:\n'
ret += f'batch: {self.decode_batch}, qLen: {self.d_q_len}, kvLen: {self.d_kv_len}\n'
ret += f'tokens: {self.d_tokens}, posIdx: {self.d_position_ids}, block_tables: {self.d_block_tables}\n'
ret += f'chunk_size={self.chunk_size}, is_last_prefill_chunk={self.is_last_prefill_chunk}\n'
return ret
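
For orientation, the per-request page-table length computed in __init__ and fill above is a ceiling division of the KV length by the page size; a small worked example with made-up numbers:

page_size = 256                    # default page size in ForwardMiniBatchSplit.__init__
past_len, chunk_len = 0, 300       # hypothetical active_position and current chunk length
kv_len = past_len + chunk_len
prefill_kv_block_len = (kv_len + page_size - 1) // page_size   # ceil(kv_len / page_size)
assert prefill_kv_block_len == 2   # 300 tokens span two 256-token pages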
@ -437,14 +749,15 @@ class ForwardBatchInput:
prefill_l = []
decode_querys_info = []
self.batch_size = 1
for (id, s, l) in prefill_minibatches:
prefill_querys_info.append(query_manager.query_map[id])
for (qid, s, l) in prefill_minibatches:
prefill_querys_info.append(query_manager.query_map[qid])
prefill_s.append(s)
prefill_l.append(l)
for decode_batch_idx in decode_mini_batches:
if query_manager.query_map[decode_batch_idx].decode_start_time is None:
query_manager.query_map[decode_batch_idx].decode_start_time =time.time()
decode_querys_info.append(query_manager.query_map[decode_batch_idx])
for decode_qid in decode_mini_batches:
qinfo = query_manager.query_map[decode_qid]
if qinfo.decode_start_time is None:
qinfo.decode_start_time = time.time()
decode_querys_info.append(qinfo)
if use_torch_npu:
minibatch = ForwardMiniBatchSplit(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)
@ -493,7 +806,7 @@ class ForwardBatchInput:
decode_querys_info.append(query_info)
if prefill_query_length*Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:
if prefill_query_length * Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:
decode_querys_info.append(query_info)
if use_torch_npu:
instance.minibatch = ForwardMiniBatchSplit(prefill_query_info, decode_querys_info, [0, 0],

View file

@ -44,7 +44,8 @@ try:
import torch_npu
use_torch_npu = torch_npu.npu.is_available()
from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
from ktransformers.models.custom_cache import KVC2StaticCache
from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM
from ktransformers.models.custom_cache import KVC2StaticCache, KVC2Qwen3Cache
except:
use_torch_npu = False
@ -70,8 +71,8 @@ class ModelRunner:
if not use_torch_npu:
model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM | KQwen3NextForCausalLM
else:
model: KNPUDeepseekV3ForCausalLM
cache: KVC2StaticCache  # TODO: only used by the NPU-adapted code path; workaround
model: KNPUDeepseekV3ForCausalLM | KNPUQwen3MoeForCausalLM
cache: KVC2StaticCache | KVC2Qwen3Cache
input: ForwardBatchInput | list[ForwardBatchInput]
output: ForwardBatchOutput
@ -210,7 +211,17 @@ class ModelRunner:
utils._USE_NPU_GRAPH = True
print("self.features_buf[npu_graph_idx] is ", self.features_buf[npu_graph_idx])
with torch.npu.graph(self.graphs[npu_graph_idx], pool=self.graph_memory_pool, stream=self.stream, auto_dispatch_capture=True):
self.outputs_buf[npu_graph_idx] = self.model(self.input_decode[npu_graph_idx], self.features_buf[npu_graph_idx], self.cache, None, None, self.page_idx_buf[npu_graph_idx], self.page_offset_buf[npu_graph_idx], self.position_ids_buf[npu_graph_idx], self.block_tables_buf[npu_graph_idx], cuda_graph_idx=npu_graph_idx, is_prefill=False)
self.outputs_buf[npu_graph_idx] = self.model(
self.input_decode[npu_graph_idx],
self.features_buf[npu_graph_idx],
self.cache, None, None,
self.page_idx_buf[npu_graph_idx],
self.page_offset_buf[npu_graph_idx],
self.position_ids_buf[npu_graph_idx],
self.block_tables_buf[npu_graph_idx],
cuda_graph_idx=npu_graph_idx,
is_prefill=False
)
self.graph_memory_pool = self.graphs[npu_graph_idx].pool()
utils._USE_NPU_GRAPH = False
@ -340,7 +351,6 @@ class ModelRunner:
def _run_infer_stage(is_prefill=True):
if "npu" in self.device:
cuda_graph_idx = batch_size_decode
# print("batch_size is ", batch_size)
if is_prefill == False:
if cuda_graph_idx != -1 and self.use_cuda_graph:
self.features = self.model.batch_embeddings(self.input_decode[cuda_graph_idx], device=self.device, is_prefill=is_prefill)
@ -370,7 +380,6 @@ class ModelRunner:
self.replay(cuda_graph_idx)
new_output = ForwardBatchOutput()
# bsz = self.outputs_buf[cuda_graph_idx].logits[0][self.input_decode[cuda_graph_idx].minibatch.d_logits_start].size(0)
for i in range(num_tokens):
new_output.top_ps.append(self.input_decode[cuda_graph_idx].minibatch.d_top_ps[i])
new_output.temperatures.append(self.input_decode[cuda_graph_idx].minibatch.d_temperatures[i])
@ -389,13 +398,11 @@ class ModelRunner:
bsz = len(new_output.logits)
if is_prefill:
for i in range(bsz):
# new_output.logits[i] = new_output.logits[i][self.input.minibatch.p_logits_start[i]:, :] # slice prefill seq[-1]
new_output.logits[i] = new_output.logits[i][-1:, :] # batched tensor do not need location
new_output.top_ps.append(self.input.minibatch.p_top_ps[i])
new_output.temperatures.append(self.input.minibatch.p_temperatures[i])
else:
for i in range(bsz):
# new_output.logits[i] = new_output.logits[i][self.input.minibatch.d_logits_start[i]:, :]
new_output.top_ps.append(self.input.minibatch.d_top_ps[i])
new_output.temperatures.append(self.input.minibatch.d_temperatures[i])

View file

@ -1,5 +1,5 @@
torch >= 2.3.0
transformers == 4.51.3
transformers >= 4.51.3
fastapi >= 0.111.0
langchain >= 0.2.0
blessed >= 1.20.0

View file

@ -1,3 +1,19 @@
# coding=utf-8
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import timedelta
@ -148,24 +164,35 @@ def get_safetensors_cut_weight(name: str, weights: torch.Tensor):
def get_absort_weight(model, config):
# Add q_absorb / out_absorb attributes
local_rank = torch.distributed.get_rank()
if not dist.is_initialized():
return
local_rank = dist.get_rank()
tp = get_tensor_parallel_size()
local_rank %= tp
tp_heads = config.num_attention_heads // tp
for i in range(config.num_hidden_layers):
self = model.model.layers[i].self_attn
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(config.num_attention_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].clone()
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].clone()
q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb.transpose(1, 2).contiguous()
setattr(self, "q_absorb", q_absorb)
setattr(self, "out_absorb", out_absorb)
del self.orig_module.kv_b_proj
torch.distributed.barrier(get_tensor_parallel_group())
attn = model.model.layers[i].self_attn
if hasattr(attn, "q_absorb") and hasattr(attn, "out_absorb"):
continue
if not (hasattr(attn, "kv_b_proj")
and hasattr(attn, "kv_lora_rank")
and hasattr(attn, "qk_nope_head_dim")):
continue
kv_b_proj = attn.kv_b_proj.weight.view(config.num_attention_heads, -1, attn.kv_lora_rank)
q_absorb = kv_b_proj[:, :attn.qk_nope_head_dim, :].clone()
out_absorb = kv_b_proj[:, attn.qk_nope_head_dim:, :].clone()
q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb.transpose(1, 2).contiguous()
setattr(attn, "q_absorb", q_absorb)
setattr(attn, "out_absorb", out_absorb)
if hasattr(attn, "orig_module") and hasattr(attn.orig_module, "kv_b_proj"):
del attn.orig_module.kv_b_proj
dist.barrier(get_tensor_parallel_group())
def allredeuce_warpper(func):
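
The reworked get_absort_weight keeps the same math as before: kv_b_proj is split into q_absorb / out_absorb and each tensor-parallel rank keeps its own contiguous block of heads. A shape-only sketch with hypothetical dimensions (not the library code itself):

import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 8, 128, 128, 512   # hypothetical sizes
tp, local_rank = 2, 1
tp_heads = num_heads // tp
kv_b_proj = torch.randn(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
q_absorb = kv_b_proj[:, :qk_nope_head_dim, :].clone()
out_absorb = kv_b_proj[:, qk_nope_head_dim:, :].clone()
# Each tensor-parallel rank keeps only its own contiguous block of heads.
q_absorb = q_absorb[local_rank * tp_heads:(local_rank + 1) * tp_heads].contiguous()
out_absorb = out_absorb[local_rank * tp_heads:(local_rank + 1) * tp_heads].contiguous()
out_absorb = out_absorb.transpose(1, 2).contiguous()
print(q_absorb.shape, out_absorb.shape)   # torch.Size([4, 128, 512]) torch.Size([4, 512, 128])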

View file

@ -666,6 +666,33 @@ def translate_name_to_gguf(name):
name = translate_name_to_gguf_mixtral(name)
if ".ffn_gate_exp." in name:
name = name.replace(".ffn_gate_exp.", ".ffn_gate_exps.")
if ".ffn_up_exp." in name:
name = name.replace(".ffn_up_exp.", ".ffn_up_exps.")
if ".ffn_down_exp." in name:
name = name.replace(".ffn_down_exp.", ".ffn_down_exps.")
m = re.match(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)", name)
if m:
layer, expert, proj = m.groups()
if proj == "gate_proj":
return f"blk.{layer}.{expert}.ffn_gate_exps"
elif proj == "up_proj":
return f"blk.{layer}.{expert}.ffn_up_exps"
else:
return f"blk.{layer}.{expert}.ffn_down_exps"
m = re.match(r"blk\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)", name)
if m:
layer, expert, proj = m.groups()
if proj == "gate_proj":
return f"blk.{layer}.{expert}.ffn_gate_exps"
elif proj == "up_proj":
return f"blk.{layer}.{expert}.ffn_up_exps"
else:
return f"blk.{layer}.{expert}.ffn_down_exps"
name = name.replace("lm_head.", "output.")
name = name.replace("model.embed_tokens.", "token_embd.")
name = name.replace("model.norm.", "output_norm.")

View file

@ -104,7 +104,11 @@ class SafeTensorLoader(ModelLoader):
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
if use_torch_npu:
tensor = f.get_tensor(key).to(torch.float16)
else:
tensor = f.get_tensor(key)
return tensor.to(device)
def load_experts(self, key: str, device: str="cpu"):

View file

@ -0,0 +1,235 @@
# coding=utf-8
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import torch
from ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf
from safetensors import safe_open
from safetensors.torch import save_file
import re
from collections import defaultdict
def read_safetensor_keys_from_folder(folder_path) -> dict:
if not os.path.exists(folder_path):
raise FileNotFoundError(f"Safetensors dir not found: {folder_path}")
if os.path.isfile(folder_path):
folder_path = os.path.dirname(folder_path)
key_to_file_map = {}
found_safetensor = False
for root, dirs, files in os.walk(folder_path):
files = sorted(files)
for file in files:
if not file.endswith(".safetensors"):
continue
found_safetensor = True
file_path = os.path.join(root, file)
try:
with safe_open(file_path, framework="pt") as f:
for key in f.keys():
key_to_file_map[key] = file_path
except Exception as e:
print(f"Error reading Safetensor file {file_path}: {e}")
if not found_safetensor:
raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
return key_to_file_map
# Optional: to also take certain non-MoE tensors from GGUF, add their key substrings to the list below
tensor_from_gguf = [] # e.g. ["self_attn.q_proj.weight"]
def translate_name(name: str) -> str:
name = translate_name_to_gguf(name)
name = name.replace(".up_proj.", ".ffn_up_exps.")
name = name.replace(".down_proj.", ".ffn_down_exps.")
name = name.replace(".gate_proj.", ".ffn_gate_exps.")
name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias")
return name
def combine_tensor_sources(safetensor_path: str, gguf_path: str):
gguf_loader = GGUFLoader(gguf_path)
gguf_tensor_file_map = gguf_loader.tensor_file_map
safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)
target_tensor_map = {}
for key, st_file in safetensor_tensor_file_map.items():
if ".mlp.experts." in key and key.endswith(".weight"):
parts = key.split(".")
if len(parts) < 8:
raise ValueError(f"Unexpected MoE expert key format: {key}")
norm_key = ".".join(parts[:5] + parts[-2:])
gguf_name = translate_name(norm_key)
if gguf_name not in gguf_tensor_file_map:
raise KeyError(
f"[MoE] GGUF tensor not found for safetensors key {key} -> {gguf_name}"
)
target_tensor_map[norm_key] = gguf_tensor_file_map[gguf_name]
continue
if any(tag in key for tag in tensor_from_gguf):
gguf_name = translate_name(key)
if gguf_name not in gguf_tensor_file_map:
raise KeyError(
f"[Non-MoE] GGUF tensor not found for safetensors key {key} -> {gguf_name}"
)
target_tensor_map[key] = gguf_tensor_file_map[gguf_name]
else:
target_tensor_map[key] = st_file
return target_tensor_map, gguf_loader
def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
os.makedirs(output_path, exist_ok=True)
safetensors_cache = {}
layer_groups = defaultdict(list)
non_layer_keys = []
layer_pattern = re.compile(r"\.layers\.(\d+)\.")
for key in target_tensor_map:
m = layer_pattern.search(key)
if m:
layer_num = int(m.group(1))
layer_groups[layer_num].append(key)
else:
non_layer_keys.append(key)
total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1
if total_shards <= 0:
raise ValueError("No tensors to save")
shard_idx = 0
if non_layer_keys:
tensors = {}
for key in non_layer_keys:
file_path = target_tensor_map[key]
tensor = None
ggml_type = None
if file_path.endswith(".safetensors"):
if file_path not in safetensors_cache:
safetensors_cache[file_path] = safe_open(file_path, framework="pt")
f = safetensors_cache[file_path]
tensor = f.get_tensor(key)
elif file_path.endswith(".gguf"):
gguf_name = translate_name(key)
tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
else:
raise ValueError(f"Unsupported file format: {file_path}")
out_key = translate_name(key)
tensors[out_key] = tensor
if ggml_type is not None:
ggml_type = torch.tensor(ggml_type)
if out_key.endswith(".weight"):
ggml_key = out_key[:-7] + ".ggml_type"
else:
ggml_key = out_key + ".ggml_type"
tensors[ggml_key] = ggml_type
output_file = os.path.join(
output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors"
)
print(f"[WRITE] Saving non-layer tensors to {output_file}")
save_file(tensors, output_file)
shard_idx += 1
for layer_num in sorted(layer_groups.keys()):
layer_keys = layer_groups[layer_num]
tensors = {}
for key in layer_keys:
file_path = target_tensor_map[key]
tensor = None
ggml_type = None
if file_path.endswith(".safetensors"):
if file_path not in safetensors_cache:
safetensors_cache[file_path] = safe_open(file_path, framework="pt")
f = safetensors_cache[file_path]
tensor = f.get_tensor(key)
elif file_path.endswith(".gguf"):
gguf_name = translate_name(key)
tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
else:
raise ValueError(f"Unsupported file format: {file_path}")
out_key = translate_name(key)
tensors[out_key] = tensor
if ggml_type is not None:
ggml_type = torch.tensor(ggml_type)
if out_key.endswith(".weight"):
ggml_key = out_key[:-7] + ".ggml_type"
else:
ggml_key = out_key + ".ggml_type"
tensors[ggml_key] = ggml_type
output_file = os.path.join(
output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors"
)
print(f"[WRITE] Saving layer {layer_num} to {output_file}")
save_file(tensors, output_file)
shard_idx += 1
def main():
parser = argparse.ArgumentParser(
description="Merge FP8 safetensors and GGUF tensors for Qwen3-30B-A3B"
)
parser.add_argument(
"--safetensor_path",
type=str,
help="Path to the FP8 Safetensor folder",
default="/mnt/data/model/Qwen3-30B-A3B-FP8",
)
parser.add_argument(
"--gguf_path",
type=str,
help="Path to the GGUF file or folder",
default="/mnt/data/model/Qwen3-30B-A3B-GGUF",
)
parser.add_argument(
"--output_path",
type=str,
help="Path to the output safetensors folder",
default="/mnt/data/model/ktrans-safetensors/Qwen3-30B-A3B-q4km-fp8",
)
args = parser.parse_args()
print("[ARGS]", args)
safetensor_path = args.safetensor_path
gguf_path = args.gguf_path
output_path = args.output_path
target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
write_combined_tensor(target_tensor_map, output_path, gguf_loader)
if __name__ == "__main__":
main()
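
Besides the CLI entry point, the two helpers can be driven directly; a minimal sketch in which all paths are placeholders:

# Programmatic use of the helpers defined above; every path is a placeholder.
safetensor_path = "/path/to/Qwen3-30B-A3B-FP8"      # folder of FP8 safetensors
gguf_path = "/path/to/Qwen3-30B-A3B-GGUF"           # GGUF file or folder
output_path = "/path/to/Qwen3-30B-A3B-q4km-fp8"     # merged shards are written here
target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
write_combined_tensor(target_tensor_map, output_path, gguf_loader)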