mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-26 10:50:59 +00:00
update: Qwen3 MoE model adaptation for NPU (framework) (#1706)
This commit is contained in:
parent
53f6a6d6e1
commit
adcfa9080f
10 changed files with 867 additions and 174 deletions
|
|
@ -1,21 +0,0 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
|
||||
# clear build dirs
|
||||
rm -rf build
|
||||
rm -rf *.egg-info
|
||||
rm -rf csrc/build
|
||||
rm -rf csrc/ktransformers_ext/build
|
||||
rm -rf csrc/ktransformers_ext/cuda/build
|
||||
rm -rf csrc/ktransformers_ext/cuda/dist
|
||||
rm -rf csrc/ktransformers_ext/cuda/*.egg-info
|
||||
rm -rf ~/.ktransformers
|
||||
echo "Installing python dependencies from requirements.txt"
|
||||
pip install -r requirements-local_chat.txt
|
||||
pip install -r ktransformers/server/requirements.txt
|
||||
echo "Installing ktransformers"
|
||||
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation
|
||||
pip install third_party/custom_flashinfer/
|
||||
|
||||
echo "Installation completed successfully"
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
- match:
|
||||
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.RotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "npu"
|
||||
prefill_device: "npu"
|
||||
|
||||
- match:
|
||||
name: "^lm_head$"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2
|
||||
kwargs:
|
||||
generate_device: "npu"
|
||||
prefill_device: "npu"
|
||||
generate_op: "KLinearTorchW8A8A2"
|
||||
prefill_op: "KLinearTorchW8A8A2"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate)(?!.*mlp\\.gate)(?!.*mlp\\.experts).*$"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2
|
||||
kwargs:
|
||||
generate_device: "npu"
|
||||
prefill_device: "npu"
|
||||
generate_op: "KLinearTorchW8A8A2"
|
||||
prefill_op: "KLinearTorchW8A8A2"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*mlp\\.gate)(?!.*self_attn\\.kv_b_proj)(?!.*mlp\\.experts).*$"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2
|
||||
kwargs:
|
||||
generate_device: "npu"
|
||||
prefill_device: "npu"
|
||||
generate_op: "KLinearTorchW8A8A2"
|
||||
prefill_op: "KLinearTorchW8A8A2"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
|
||||
replace:
|
||||
class: ktransformers.operators.ascend.ascend_experts.KQwen3MoeSparseMoeBlockW8A8
|
||||
kwargs:
|
||||
generate_device: "npu"
|
||||
prefill_device: "npu"
|
||||
dump_enable: False
|
||||
dump_dir: "/mnt/dump_from_mindie/dump_from_kt_moe"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.ascend.ascend_attention.KQwen3MoeAttentionW8A8A2Serve
|
||||
kwargs:
|
||||
generate_device: "npu"
|
||||
prefill_device: "npu"
|
||||
absorb_for_prefill: False
|
||||
dump_enable: False
|
||||
dump_dir: "/mnt/dump_from_mindie/dump_from_kt_attn"
|
||||
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KQwen2MoeModel"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0
|
||||
|
||||
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
- match:
|
||||
class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm
|
||||
replace:
|
||||
class: ktransformers.operators.ascend.ascend_layernorm.KQwen3MoeRMSNormW8A8
|
||||
kwargs:
|
||||
generate_device: "npu"
|
||||
prefill_device: "npu"
|
||||
dump_enable: False
|
||||
dump_dir: "/mnt/dump_from_mindie/dump_from_kt_rms"
|
||||
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Any, AsyncIterator, List, Optional, Set
|
||||
from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache
|
||||
from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache, KVC2Qwen3Cache
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
|
|
@ -39,6 +39,7 @@ except:
|
|||
use_torch_npu = False
|
||||
if use_torch_npu:
|
||||
from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
|
||||
from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM
|
||||
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size
|
||||
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
||||
|
|
@ -50,7 +51,7 @@ custom_models = {
|
|||
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"MixtralForCausalLM": MixtralForCausalLM,
|
||||
} #TODO 独有?
|
||||
}
|
||||
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner #TODO get_or_create_model_runner npu独有?
|
||||
from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig
|
||||
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
|
||||
|
|
@ -198,11 +199,15 @@ class Engine:
|
|||
self.cache = KDeepSeekV3Cache(config, self.args.page_size)
|
||||
self.model = KDeepseekV2ForCausalLM(config, self.cache)
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||
self.cache = KGQACache(config, self.args.page_size)
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
self.model = KQwen2MoeForCausalLM(config, self.cache)
|
||||
if not use_torch_npu:
|
||||
self.cache = KGQACache(config, self.args.page_size)
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
self.model = KQwen2MoeForCausalLM(config, self.cache)
|
||||
else:
|
||||
self.model = KQwen3MoeForCausalLM(config, self.cache)
|
||||
else:
|
||||
self.model = KQwen3MoeForCausalLM(config, self.cache)
|
||||
self.cache = KVC2Qwen3Cache(config, args.max_batch_size, self.args.page_size)
|
||||
self.model = KNPUQwen3MoeForCausalLM(config, self.cache)
|
||||
elif config.architectures[0] == "SmallThinkerForCausalLM":
|
||||
self.cache = KGQACache(config, self.args.page_size)
|
||||
self.model = KSmallThinkerForCausalLM(config, self.cache)
|
||||
|
|
@ -277,7 +282,11 @@ class Engine:
|
|||
# self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
|
||||
#@TODO add config
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM" or config.architectures[0] == "Qwen3NextForCausalLM":
|
||||
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
|
||||
if not use_torch_npu:
|
||||
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
|
||||
else:
|
||||
# npu donnot support flash attn
|
||||
self.model.init_wrapper()
|
||||
else:
|
||||
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
|
||||
|
||||
|
|
@ -322,7 +331,7 @@ class Engine:
|
|||
batch_size = 0
|
||||
for i in range(len(self.batch.decode_mini_batches)):
|
||||
batch_size += len(self.batch.decode_mini_batches[i])
|
||||
logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
|
||||
# logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
|
||||
self.model_runner.run_split(self.batch, self.query_manager)
|
||||
else:
|
||||
self.model_runner.run(self.batch, self.query_manager)
|
||||
|
|
@ -403,9 +412,12 @@ def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank
|
|||
engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
|
||||
if args.use_cuda_graph:
|
||||
if 'npu' in engine.device:
|
||||
print(f"[WARMUP-NPU] start", flush=True)
|
||||
engine.model_runner.warmup_npu()
|
||||
else:
|
||||
engine.model_runner.warmup()
|
||||
else:
|
||||
print(f"[WARMUP-NPU] skip warmup, eager mode!", flush=True)
|
||||
if use_torch_npu:
|
||||
args.port += torch.distributed.get_rank()
|
||||
event.set()
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ class ForwardMiniBatchCombine:
|
|||
|
||||
|
||||
class ForwardMiniBatchSplit:
|
||||
# NPU流程prefill和decode无法合并,需单独统计
|
||||
# NPU 流程 prefill 和 decode 分开打包
|
||||
prefill_batch: int
|
||||
p_q_len: torch.Tensor # (bsz)
|
||||
p_kv_len: torch.Tensor # (bsz)
|
||||
|
|
@ -183,99 +183,261 @@ class ForwardMiniBatchSplit:
|
|||
p_tokens: torch.Tensor # (sum(q_len))
|
||||
p_temperatures: torch.Tensor # (bsz)
|
||||
p_top_ps: torch.Tensor # (bsz)
|
||||
p_block_tables: torch.Tensor # (bsz * maxBlockNum)
|
||||
p_block_tables: torch.Tensor # (bsz, max_page_num)
|
||||
p_logits_start: list
|
||||
|
||||
decode_batch: int
|
||||
d_q_len: torch.Tensor
|
||||
d_kv_len: torch.Tensor
|
||||
d_position_ids: torch.Tensor
|
||||
d_tokens: torch.Tensor
|
||||
d_temperatures: torch.Tensor
|
||||
d_top_ps: torch.Tensor
|
||||
d_block_tables: torch.Tensor # (bsz * maxBlockNum)
|
||||
d_block_tables: torch.Tensor # (bsz, max_page_num)
|
||||
d_logits_start: list
|
||||
|
||||
chunk_size: int
|
||||
is_last_prefill_chunk: bool
|
||||
|
||||
def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo],
|
||||
prefill_s: list[int] = None, prefill_l: list[int] = None,
|
||||
device=None, page_size=256, max_page_num=64,
|
||||
decode_padding_len: int = 1):
|
||||
def __init__(
|
||||
self,
|
||||
prefill_querys_info: list[QueryInfo],
|
||||
decode_querys_info: list[QueryInfo],
|
||||
prefill_s: list[int] = None,
|
||||
prefill_l: list[int] = None,
|
||||
device=None,
|
||||
page_size: int = 256,
|
||||
max_page_num: int = 64,
|
||||
decode_padding_len: int = 1,
|
||||
):
|
||||
# 统一 NPU 设备
|
||||
device = torch.device('npu')
|
||||
batch_decode = len(decode_querys_info)
|
||||
# batch_prefill = len(prefill_querys_info)
|
||||
# update valid prefill batch
|
||||
new_prefill_querys_info = []
|
||||
for info in prefill_querys_info:
|
||||
if info is not None:
|
||||
new_prefill_querys_info.append(info)
|
||||
batch_prefill = len(new_prefill_querys_info)
|
||||
|
||||
self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l)
|
||||
if prefill_s is None or prefill_l is None:
|
||||
raise ValueError(
|
||||
"[ForwardMiniBatchSplit.__init__] prefill_s / prefill_l 不能为空,chunk prefill 需要这两个参数"
|
||||
)
|
||||
|
||||
# 过滤掉 None
|
||||
new_prefill_querys_info: list[QueryInfo] = [
|
||||
info for info in prefill_querys_info if info is not None
|
||||
]
|
||||
batch_prefill = len(new_prefill_querys_info)
|
||||
batch_decode = len(decode_querys_info)
|
||||
|
||||
self.prefill_batch = batch_prefill
|
||||
self.decode_batch = batch_decode
|
||||
self.batch_size = batch_decode + batch_prefill
|
||||
self.batch_size = batch_prefill + batch_decode
|
||||
self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l)
|
||||
|
||||
self.chunk_size = prefill_l[0] if prefill_l else 0
|
||||
|
||||
self.is_last_prefill_chunk = True
|
||||
for i, q in enumerate(new_prefill_querys_info):
|
||||
end_pos = prefill_s[i] + prefill_l[i]
|
||||
if end_pos < q.query_length:
|
||||
self.is_last_prefill_chunk = False
|
||||
break
|
||||
|
||||
# ====================== Prefill 部分 ======================
|
||||
self.p_q_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.p_kv_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.p_position_ids = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.p_block_tables = -1 * torch.ones([self.prefill_batch, max_page_num], device=device, dtype=torch.int32)
|
||||
# self.p_kv_page_offset = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.p_block_tables = -1 * torch.ones(
|
||||
[self.prefill_batch, max_page_num], device=device, dtype=torch.int32
|
||||
)
|
||||
self.p_tokens = torch.tensor([], device=device, dtype=torch.int32)
|
||||
|
||||
self.p_temperatures = torch.tensor([], device=device, dtype=torch.float32)
|
||||
self.p_top_ps = torch.tensor([], device=device, dtype=torch.float32)
|
||||
self.p_logits_start = []
|
||||
self.p_logits_start: list[int] = []
|
||||
|
||||
for i, prefill_query_info in enumerate(new_prefill_querys_info):
|
||||
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0
|
||||
assert prefill_query_info.active_position == 0, '[ERROR] currently do not support prefix cache or chunk prefill in balance serving!'
|
||||
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
|
||||
self.p_q_len = torch.concat((self.p_q_len, torch.tensor([prefill_l[i]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.p_kv_len = torch.concat((self.p_kv_len, torch.tensor([prefill_query_info.active_position + prefill_l[i]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[:prefill_kv_block_len]
|
||||
# self.p_kv_page_offset = torch.concat((self.p_kv_page_offset, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
|
||||
self.p_position_ids = torch.concat((self.p_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
|
||||
self.p_tokens = torch.concat((self.p_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
|
||||
self.p_logits_start.append(prefill_l[i] - 1 if len(self.p_logits_start) == 0 else sum(prefill_l[:i+1])-1)
|
||||
qid = getattr(prefill_query_info, "id", -1)
|
||||
|
||||
self.p_temperatures = torch.concat((self.p_temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
|
||||
self.p_top_ps = torch.concat((self.p_top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
|
||||
past_len = int(prefill_query_info.active_position)
|
||||
start = int(prefill_s[i]) # current chunk's start position in query_tokens
|
||||
chunk_len = int(prefill_l[i])
|
||||
kv_len = past_len + chunk_len
|
||||
prefill_kv_block_len = (kv_len + page_size - 1) // page_size
|
||||
|
||||
# Q length = current chunk length
|
||||
self.p_q_len = torch.concat(
|
||||
(
|
||||
self.p_q_len,
|
||||
torch.tensor([chunk_len], device=device, dtype=torch.int32),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
self.p_kv_len = torch.concat(
|
||||
(
|
||||
self.p_kv_len,
|
||||
torch.tensor([kv_len], device=device, dtype=torch.int32),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[
|
||||
:prefill_kv_block_len
|
||||
]
|
||||
|
||||
self.p_position_ids = torch.concat(
|
||||
(
|
||||
self.p_position_ids,
|
||||
torch.arange(
|
||||
start,
|
||||
start + chunk_len,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.p_tokens = torch.concat(
|
||||
(
|
||||
self.p_tokens,
|
||||
prefill_query_info.query_tokens[start : start + chunk_len],
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.p_logits_start.append(
|
||||
chunk_len - 1
|
||||
if len(self.p_logits_start) == 0
|
||||
else sum(prefill_l[: i + 1]) - 1
|
||||
)
|
||||
|
||||
self.p_temperatures = torch.concat(
|
||||
(
|
||||
self.p_temperatures,
|
||||
torch.tensor(
|
||||
[prefill_query_info.temperature],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
self.p_top_ps = torch.concat(
|
||||
(
|
||||
self.p_top_ps,
|
||||
torch.tensor(
|
||||
[prefill_query_info.top_p],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# ====================== Decode ======================
|
||||
self.d_q_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.d_kv_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.d_position_ids = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.d_block_tables = -1 * torch.ones([self.decode_batch, max_page_num], device=device, dtype=torch.int32)
|
||||
# self.p_kv_page_offset = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.d_block_tables = -1 * torch.ones(
|
||||
[self.decode_batch, max_page_num], device=device, dtype=torch.int32
|
||||
)
|
||||
self.d_tokens = torch.tensor([], device=device, dtype=torch.int32)
|
||||
|
||||
self.d_temperatures = torch.tensor([], device=device, dtype=torch.float32)
|
||||
self.d_top_ps = torch.tensor([], device=device, dtype=torch.float32)
|
||||
self.d_logits_start = []
|
||||
self.d_logits_start: list[int] = []
|
||||
|
||||
# 1 2 ...
|
||||
# 1
|
||||
# postion
|
||||
# page table
|
||||
for i, decode_query_info in enumerate(decode_querys_info):
|
||||
# print("decode_query_info.active_position is ", decode_query_info.active_position)
|
||||
qid = getattr(decode_query_info, "id", -1)
|
||||
past_len = int(decode_query_info.active_position)
|
||||
decode_kv_block_len = (past_len + decode_padding_len + page_size - 1) // page_size
|
||||
|
||||
decode_kv_block_len = (decode_query_info.active_position + decode_padding_len + page_size - 1) // page_size
|
||||
self.d_q_len = torch.concat((self.d_q_len, torch.tensor([decode_padding_len], device=device, dtype=torch.int32)), dim=0)
|
||||
self.d_kv_len = torch.concat((self.d_kv_len, torch.tensor([decode_query_info.active_position + decode_padding_len], device=device, dtype=torch.int32)), dim=0)
|
||||
self.d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[:decode_kv_block_len]
|
||||
# self.d_kv_page_offset = torch.concat((self.d_kv_page_offset, torch.tensor([(decode_query_info.active_position + decode_padding_len) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
|
||||
self.d_position_ids = torch.concat((self.d_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + decode_padding_len, device=device, dtype=torch.int32)), dim=0)
|
||||
if decode_query_info.active_position > 0:
|
||||
self.d_tokens = torch.concat((self.d_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+decode_padding_len]), dim=0)
|
||||
self.d_q_len = torch.concat(
|
||||
(
|
||||
self.d_q_len,
|
||||
torch.tensor(
|
||||
[decode_padding_len], device=device, dtype=torch.int32
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
self.d_kv_len = torch.concat(
|
||||
(
|
||||
self.d_kv_len,
|
||||
torch.tensor(
|
||||
[past_len + decode_padding_len],
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[
|
||||
:decode_kv_block_len
|
||||
]
|
||||
|
||||
self.d_position_ids = torch.concat(
|
||||
(
|
||||
self.d_position_ids,
|
||||
torch.arange(
|
||||
past_len,
|
||||
past_len + decode_padding_len,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if past_len > 0:
|
||||
self.d_tokens = torch.concat(
|
||||
(
|
||||
self.d_tokens,
|
||||
decode_query_info.query_tokens[
|
||||
past_len : past_len + decode_padding_len
|
||||
],
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
self.d_tokens = torch.concat((self.d_tokens, torch.tensor([0] * decode_padding_len, device=device, dtype=torch.int32)), dim=0)
|
||||
self.d_logits_start.append(0 if len(self.d_logits_start) == 0 else self.d_logits_start[-1]+decode_padding_len)
|
||||
# print("self.d_position_ids is ", self.d_position_ids)
|
||||
self.d_tokens = torch.concat(
|
||||
(
|
||||
self.d_tokens,
|
||||
torch.tensor(
|
||||
[0] * decode_padding_len,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.d_temperatures = torch.concat((self.d_temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
|
||||
self.d_top_ps = torch.concat((self.d_top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
|
||||
self.d_logits_start.append(
|
||||
0
|
||||
if len(self.d_logits_start) == 0
|
||||
else self.d_logits_start[-1] + decode_padding_len
|
||||
)
|
||||
|
||||
self.d_temperatures = torch.concat(
|
||||
(
|
||||
self.d_temperatures,
|
||||
torch.tensor(
|
||||
[decode_query_info.temperature],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
self.d_top_ps = torch.concat(
|
||||
(
|
||||
self.d_top_ps,
|
||||
torch.tensor(
|
||||
[decode_query_info.top_p],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.p_q_len = self.p_q_len.contiguous()
|
||||
self.p_kv_len = self.p_kv_len.contiguous()
|
||||
|
|
@ -291,7 +453,6 @@ class ForwardMiniBatchSplit:
|
|||
self.d_position_ids = self.d_position_ids.reshape(self.decode_batch, -1).contiguous()
|
||||
self.d_tokens = self.d_tokens.reshape(self.decode_batch, -1).contiguous()
|
||||
else:
|
||||
# TODO remove this
|
||||
self.d_q_len = self.d_q_len.contiguous()
|
||||
self.d_kv_len = self.d_kv_len.contiguous()
|
||||
self.d_kv_len_list = self.d_kv_len.flatten().tolist()
|
||||
|
|
@ -301,30 +462,53 @@ class ForwardMiniBatchSplit:
|
|||
|
||||
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
|
||||
|
||||
def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, decode_padding_len=1, device = None, page_size = 256, max_page_num=64):
|
||||
device = torch.device('npu')
|
||||
|
||||
page_size = 128
|
||||
|
||||
batch_decode = len(decode_querys_info)
|
||||
# batch_prefill = len(prefill_querys_info)
|
||||
# update valid prefill batch
|
||||
new_prefill_querys_info = []
|
||||
for info in prefill_querys_info:
|
||||
if info is not None:
|
||||
new_prefill_querys_info.append(info)
|
||||
batch_prefill = len(new_prefill_querys_info)
|
||||
|
||||
self.num_tokens = batch_decode + sum(prefill_l)
|
||||
def fill(
|
||||
self,
|
||||
prefill_querys_info: list[QueryInfo],
|
||||
decode_querys_info: list[QueryInfo],
|
||||
prefill_s: list[int] = None,
|
||||
prefill_l: list[int] = None,
|
||||
decode_padding_len: int = 1,
|
||||
device=None,
|
||||
page_size: int = 256,
|
||||
max_page_num: int = 64,
|
||||
):
|
||||
device = torch.device('npu')
|
||||
|
||||
if prefill_s is None or prefill_l is None:
|
||||
raise ValueError(
|
||||
"[ForwardMiniBatchSplit.fill] prefill_s / prefill_l 不能为空,chunk prefill 需要这两个参数"
|
||||
)
|
||||
|
||||
page_size = 128
|
||||
|
||||
new_prefill_querys_info: list[QueryInfo] = [
|
||||
info for info in prefill_querys_info if info is not None
|
||||
]
|
||||
batch_prefill = len(new_prefill_querys_info)
|
||||
batch_decode = len(decode_querys_info)
|
||||
|
||||
self.prefill_batch = batch_prefill
|
||||
self.decode_batch = batch_decode
|
||||
self.batch_size = batch_decode + batch_prefill
|
||||
self.batch_size = batch_prefill + batch_decode
|
||||
self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l)
|
||||
|
||||
self.chunk_size = prefill_l[0] if prefill_l else 0
|
||||
self.is_last_prefill_chunk = True
|
||||
for i, q in enumerate(new_prefill_querys_info):
|
||||
end_pos = prefill_s[i] + prefill_l[i]
|
||||
if end_pos < q.query_length:
|
||||
self.is_last_prefill_chunk = False
|
||||
break
|
||||
|
||||
# ---------- Prefill ----------
|
||||
self.p_q_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.p_kv_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
new_p_position_ids = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.p_block_tables = torch.zeros([self.prefill_batch, max_page_num], device=device, dtype=torch.int32)
|
||||
# self.p_kv_page_offset = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.p_block_tables = torch.zeros(
|
||||
[self.prefill_batch, max_page_num], device=device, dtype=torch.int32
|
||||
)
|
||||
new_p_tokens = torch.tensor([], device=device, dtype=torch.int32)
|
||||
|
||||
self.p_temperatures = torch.tensor([], device=device, dtype=torch.float32)
|
||||
|
|
@ -332,58 +516,190 @@ class ForwardMiniBatchSplit:
|
|||
self.p_logits_start = []
|
||||
|
||||
for i, prefill_query_info in enumerate(new_prefill_querys_info):
|
||||
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0
|
||||
assert prefill_query_info.active_position == 0, '[ERROR] currently do not support prefix cache or chunk prefill in balance serving!'
|
||||
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
|
||||
self.p_q_len = torch.concat((self.p_q_len, torch.tensor([prefill_l[i]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.p_kv_len = torch.concat((self.p_kv_len, torch.tensor([prefill_query_info.active_position + prefill_l[i]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[:prefill_kv_block_len]
|
||||
# self.p_kv_page_offset = torch.concat((self.p_kv_page_offset, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
|
||||
new_p_position_ids = torch.concat((new_p_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
|
||||
new_p_tokens = torch.concat((new_p_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
|
||||
self.p_logits_start.append(prefill_l[i] - 1 if len(self.p_logits_start) == 0 else sum(prefill_l[:i+1])-1)
|
||||
qid = getattr(prefill_query_info, "id", -1)
|
||||
past_len = int(prefill_query_info.active_position)
|
||||
start = int(prefill_s[i])
|
||||
chunk_len = int(prefill_l[i])
|
||||
|
||||
self.p_temperatures = torch.concat((self.p_temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
|
||||
self.p_top_ps = torch.concat((self.p_top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
|
||||
kv_len = past_len + chunk_len
|
||||
prefill_kv_block_len = (kv_len + page_size - 1) // page_size
|
||||
|
||||
self.d_q_len = torch.zeros([1] * self.decode_batch, device=device, dtype=torch.int32)
|
||||
self.p_q_len = torch.concat(
|
||||
(
|
||||
self.p_q_len,
|
||||
torch.tensor([chunk_len], device=device, dtype=torch.int32),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
self.p_kv_len = torch.concat(
|
||||
(
|
||||
self.p_kv_len,
|
||||
torch.tensor([kv_len], device=device, dtype=torch.int32),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[
|
||||
:prefill_kv_block_len
|
||||
]
|
||||
|
||||
new_p_position_ids = torch.concat(
|
||||
(
|
||||
new_p_position_ids,
|
||||
torch.arange(
|
||||
start,
|
||||
start + chunk_len,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_p_tokens = torch.concat(
|
||||
(
|
||||
new_p_tokens,
|
||||
prefill_query_info.query_tokens[start : start + chunk_len],
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.p_logits_start.append(
|
||||
chunk_len - 1 if len(self.p_logits_start) == 0 else sum(prefill_l[: i + 1]) - 1
|
||||
)
|
||||
|
||||
self.p_temperatures = torch.concat(
|
||||
(
|
||||
self.p_temperatures,
|
||||
torch.tensor(
|
||||
[prefill_query_info.temperature],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
self.p_top_ps = torch.concat(
|
||||
(
|
||||
self.p_top_ps,
|
||||
torch.tensor(
|
||||
[prefill_query_info.top_p],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if new_p_position_ids.numel() > 0:
|
||||
self.p_position_ids = new_p_position_ids.contiguous()
|
||||
if new_p_tokens.numel() > 0:
|
||||
self.p_tokens = new_p_tokens.contiguous()
|
||||
|
||||
# ---------- Decode ----------
|
||||
self.d_q_len = torch.zeros(
|
||||
[1] * self.decode_batch, device=device, dtype=torch.int32
|
||||
)
|
||||
self.d_kv_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
new_d_position_ids = torch.tensor([], device=device, dtype=torch.int32)
|
||||
new_d_block_tables = -1 * torch.ones([self.decode_batch, max_page_num], device=device, dtype=torch.int32)
|
||||
# self.p_kv_page_offset = torch.tensor([], device=device, dtype=torch.int32)
|
||||
new_d_block_tables = -1 * torch.ones(
|
||||
[self.decode_batch, max_page_num], device=device, dtype=torch.int32
|
||||
)
|
||||
new_d_tokens = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.d_logits_start = []
|
||||
|
||||
self.d_logits_start = []
|
||||
self.d_temperatures = torch.tensor([], device=device, dtype=torch.float32)
|
||||
self.d_top_ps = torch.tensor([], device=device, dtype=torch.float32)
|
||||
|
||||
for i, decode_query_info in enumerate(decode_querys_info):
|
||||
decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size
|
||||
self.d_kv_len = torch.concat((self.d_kv_len, torch.tensor([decode_query_info.active_position + 1], device=device, dtype=torch.int32)), dim=0)
|
||||
# print("fill self.d_block_tables is ", self.d_block_tables)
|
||||
new_d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[:decode_kv_block_len]
|
||||
# print("decode_query_info.block_index[:decode_kv_block_len] is ", decode_query_info.block_index[:decode_kv_block_len])
|
||||
# self.d_kv_page_offset = torch.concat((self.d_kv_page_offset, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
|
||||
new_d_position_ids = torch.concat((new_d_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)
|
||||
# print("decode_query_info.active_position is ", decode_query_info.active_position)
|
||||
qid = getattr(decode_query_info, "id", -1)
|
||||
past_len = int(decode_query_info.active_position)
|
||||
decode_kv_block_len = (past_len + decode_padding_len + page_size - 1) // page_size
|
||||
|
||||
if decode_query_info.active_position > 0:
|
||||
new_d_tokens = torch.concat((new_d_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)
|
||||
self.d_kv_len = torch.concat(
|
||||
(
|
||||
self.d_kv_len,
|
||||
torch.tensor(
|
||||
[past_len + decode_padding_len],
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[
|
||||
:decode_kv_block_len
|
||||
]
|
||||
|
||||
new_d_position_ids = torch.concat(
|
||||
(
|
||||
new_d_position_ids,
|
||||
torch.arange(
|
||||
past_len,
|
||||
past_len + decode_padding_len,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if past_len > 0:
|
||||
new_d_tokens = torch.concat(
|
||||
(
|
||||
new_d_tokens,
|
||||
decode_query_info.query_tokens[
|
||||
past_len : past_len + decode_padding_len
|
||||
],
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
new_d_tokens = torch.concat((new_d_tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)
|
||||
self.d_logits_start.append(0 if len(self.d_logits_start) == 0 else self.d_logits_start[-1]+1)
|
||||
new_d_tokens = torch.concat(
|
||||
(
|
||||
new_d_tokens,
|
||||
torch.tensor(
|
||||
[0] * decode_padding_len,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.d_temperatures = torch.concat((self.d_temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
|
||||
self.d_top_ps = torch.concat((self.d_top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
|
||||
self.d_logits_start.append(
|
||||
0
|
||||
if len(self.d_logits_start) == 0
|
||||
else self.d_logits_start[-1] + decode_padding_len
|
||||
)
|
||||
|
||||
self.d_temperatures = torch.concat(
|
||||
(
|
||||
self.d_temperatures,
|
||||
torch.tensor(
|
||||
[decode_query_info.temperature],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
self.d_top_ps = torch.concat(
|
||||
(
|
||||
self.d_top_ps,
|
||||
torch.tensor(
|
||||
[decode_query_info.top_p],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if len(decode_querys_info) > 1:
|
||||
self.d_position_ids[i].copy_(new_d_position_ids[i])
|
||||
# self.d_position_ids[i:].zero_()
|
||||
self.d_tokens[i].copy_(new_d_tokens[i])
|
||||
self.d_block_tables[i].copy_(new_d_block_tables[i])
|
||||
else:
|
||||
self.d_position_ids[:new_d_position_ids.size(0)].copy_(new_d_position_ids)
|
||||
# self.d_position_ids[new_d_position_ids.size(0):].zero_()
|
||||
self.d_tokens[:new_d_tokens.size(0)].copy_(new_d_tokens)
|
||||
self.d_block_tables[0].copy_(new_d_block_tables[0])
|
||||
|
||||
|
|
@ -391,15 +707,10 @@ class ForwardMiniBatchSplit:
|
|||
self.p_q_len = self.p_q_len.contiguous()
|
||||
self.p_kv_len = self.p_kv_len.contiguous()
|
||||
self.p_block_tables = self.p_block_tables.contiguous()
|
||||
# self.p_position_ids = self.p_position_ids.contiguous()
|
||||
# self.p_tokens = self.p_tokens.contiguous()
|
||||
|
||||
self.d_q_len = self.d_q_len.contiguous()
|
||||
self.d_kv_len = self.d_kv_len.contiguous()
|
||||
self.d_kv_len_list = self.d_kv_len.flatten().tolist()
|
||||
# self.d_block_tables = self.d_block_tables.contiguous()
|
||||
# self.d_position_ids = self.d_position_ids.contiguous()
|
||||
# self.d_tokens = self.d_tokens.contiguous()
|
||||
|
||||
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
|
||||
|
||||
|
|
@ -407,12 +718,13 @@ class ForwardMiniBatchSplit:
|
|||
|
||||
def __str__(self):
|
||||
ret = ''
|
||||
ret += f'=======Prefill forward info:\n'
|
||||
ret += f'batch: {self.prefill_batch=}, qLen: {self.p_q_len=}, kvLen: {self.p_kv_len=}\n'
|
||||
ret += f'tokens: {self.p_tokens=}, posIdx: {self.p_position_ids=}, block_tables: {self.p_block_tables=}\n'
|
||||
ret += f'=======Decode forward info:\n'
|
||||
ret += f'batch: {self.decode_batch=}, qLen: {self.d_q_len=}, kvLen: {self.d_kv_len=}\n'
|
||||
ret += f'tokens: {self.d_tokens=}, posIdx: {self.d_position_ids=}, block_tables: {self.d_block_tables=}\n'
|
||||
ret += '=======Prefill forward info:\n'
|
||||
ret += f'batch: {self.prefill_batch}, qLen: {self.p_q_len}, kvLen: {self.p_kv_len}\n'
|
||||
ret += f'tokens: {self.p_tokens}, posIdx: {self.p_position_ids}, block_tables: {self.p_block_tables}\n'
|
||||
ret += '=======Decode forward info:\n'
|
||||
ret += f'batch: {self.decode_batch}, qLen: {self.d_q_len}, kvLen: {self.d_kv_len}\n'
|
||||
ret += f'tokens: {self.d_tokens}, posIdx: {self.d_position_ids}, block_tables: {self.d_block_tables}\n'
|
||||
ret += f'chunk_size={self.chunk_size}, is_last_prefill_chunk={self.is_last_prefill_chunk}\n'
|
||||
return ret
|
||||
|
||||
|
||||
|
|
@ -437,14 +749,15 @@ class ForwardBatchInput:
|
|||
prefill_l = []
|
||||
decode_querys_info = []
|
||||
self.batch_size = 1
|
||||
for (id, s, l) in prefill_minibatches:
|
||||
prefill_querys_info.append(query_manager.query_map[id])
|
||||
for (qid, s, l) in prefill_minibatches:
|
||||
prefill_querys_info.append(query_manager.query_map[qid])
|
||||
prefill_s.append(s)
|
||||
prefill_l.append(l)
|
||||
for decode_batch_idx in decode_mini_batches:
|
||||
if query_manager.query_map[decode_batch_idx].decode_start_time is None:
|
||||
query_manager.query_map[decode_batch_idx].decode_start_time =time.time()
|
||||
decode_querys_info.append(query_manager.query_map[decode_batch_idx])
|
||||
for decode_qid in decode_mini_batches:
|
||||
qinfo = query_manager.query_map[decode_qid]
|
||||
if qinfo.decode_start_time is None:
|
||||
qinfo.decode_start_time = time.time()
|
||||
decode_querys_info.append(qinfo)
|
||||
|
||||
if use_torch_npu:
|
||||
minibatch = ForwardMiniBatchSplit(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)
|
||||
|
|
@ -493,7 +806,7 @@ class ForwardBatchInput:
|
|||
|
||||
decode_querys_info.append(query_info)
|
||||
|
||||
if prefill_query_length*Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:
|
||||
if prefill_query_length * Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:
|
||||
decode_querys_info.append(query_info)
|
||||
if use_torch_npu:
|
||||
instance.minibatch = ForwardMiniBatchSplit(prefill_query_info, decode_querys_info, [0, 0],
|
||||
|
|
|
|||
|
|
@ -44,7 +44,8 @@ try:
|
|||
import torch_npu
|
||||
use_torch_npu = torch_npu.npu.is_available()
|
||||
from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
|
||||
from ktransformers.models.custom_cache import KVC2StaticCache
|
||||
from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM
|
||||
from ktransformers.models.custom_cache import KVC2StaticCache, KVC2Qwen3Cache
|
||||
except:
|
||||
use_torch_npu = False
|
||||
|
||||
|
|
@ -70,8 +71,8 @@ class ModelRunner:
|
|||
if not use_torch_npu:
|
||||
model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM | KQwen3NextForCausalLM
|
||||
else:
|
||||
model: KNPUDeepseekV3ForCausalLM
|
||||
cache: KVC2StaticCache #TODO 只有npu适配的代码里用到,规避
|
||||
model: KNPUDeepseekV3ForCausalLM | KNPUQwen3MoeForCausalLM
|
||||
cache: KVC2StaticCache | KVC2Qwen3Cache
|
||||
input: ForwardBatchInput | list[ForwardBatchInput]
|
||||
output: ForwardBatchOutput
|
||||
|
||||
|
|
@ -210,7 +211,17 @@ class ModelRunner:
|
|||
utils._USE_NPU_GRAPH = True
|
||||
print("self.features_buf[npu_graph_idx] is ", self.features_buf[npu_graph_idx])
|
||||
with torch.npu.graph(self.graphs[npu_graph_idx], pool=self.graph_memory_pool, stream=self.stream, auto_dispatch_capture=True):
|
||||
self.outputs_buf[npu_graph_idx] = self.model(self.input_decode[npu_graph_idx], self.features_buf[npu_graph_idx], self.cache, None, None, self.page_idx_buf[npu_graph_idx], self.page_offset_buf[npu_graph_idx], self.position_ids_buf[npu_graph_idx], self.block_tables_buf[npu_graph_idx], cuda_graph_idx=npu_graph_idx, is_prefill=False)
|
||||
self.outputs_buf[npu_graph_idx] = self.model(
|
||||
self.input_decode[npu_graph_idx],
|
||||
self.features_buf[npu_graph_idx],
|
||||
self.cache, None, None,
|
||||
self.page_idx_buf[npu_graph_idx],
|
||||
self.page_offset_buf[npu_graph_idx],
|
||||
self.position_ids_buf[npu_graph_idx],
|
||||
self.block_tables_buf[npu_graph_idx],
|
||||
cuda_graph_idx=npu_graph_idx,
|
||||
is_prefill=False
|
||||
)
|
||||
self.graph_memory_pool = self.graphs[npu_graph_idx].pool()
|
||||
utils._USE_NPU_GRAPH = False
|
||||
|
||||
|
|
@ -340,7 +351,6 @@ class ModelRunner:
|
|||
def _run_infer_stage(is_prefill=True):
|
||||
if "npu" in self.device:
|
||||
cuda_graph_idx = batch_size_decode
|
||||
# print("batch_size is ", batch_size)
|
||||
if is_prefill == False:
|
||||
if cuda_graph_idx != -1 and self.use_cuda_graph:
|
||||
self.features = self.model.batch_embeddings(self.input_decode[cuda_graph_idx], device=self.device, is_prefill=is_prefill)
|
||||
|
|
@ -370,7 +380,6 @@ class ModelRunner:
|
|||
|
||||
self.replay(cuda_graph_idx)
|
||||
new_output = ForwardBatchOutput()
|
||||
# bsz = self.outputs_buf[cuda_graph_idx].logits[0][self.input_decode[cuda_graph_idx].minibatch.d_logits_start].size(0)
|
||||
for i in range(num_tokens):
|
||||
new_output.top_ps.append(self.input_decode[cuda_graph_idx].minibatch.d_top_ps[i])
|
||||
new_output.temperatures.append(self.input_decode[cuda_graph_idx].minibatch.d_temperatures[i])
|
||||
|
|
@ -389,13 +398,11 @@ class ModelRunner:
|
|||
bsz = len(new_output.logits)
|
||||
if is_prefill:
|
||||
for i in range(bsz):
|
||||
# new_output.logits[i] = new_output.logits[i][self.input.minibatch.p_logits_start[i]:, :] # slice prefill seq[-1]
|
||||
new_output.logits[i] = new_output.logits[i][-1:, :] # batched tensor do not need location
|
||||
new_output.top_ps.append(self.input.minibatch.p_top_ps[i])
|
||||
new_output.temperatures.append(self.input.minibatch.p_temperatures[i])
|
||||
else:
|
||||
for i in range(bsz):
|
||||
# new_output.logits[i] = new_output.logits[i][self.input.minibatch.d_logits_start[i]:, :]
|
||||
new_output.top_ps.append(self.input.minibatch.d_top_ps[i])
|
||||
new_output.temperatures.append(self.input.minibatch.d_temperatures[i])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
torch >= 2.3.0
|
||||
transformers == 4.51.3
|
||||
transformers >= 4.51.3
|
||||
fastapi >= 0.111.0
|
||||
langchain >= 0.2.0
|
||||
blessed >= 1.20.0
|
||||
|
|
|
|||
|
|
@ -1,3 +1,19 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
|
|
@ -148,24 +164,35 @@ def get_safetensors_cut_weight(name: str, weights: torch.Tensor):
|
|||
|
||||
|
||||
def get_absort_weight(model, config):
|
||||
# 新增q_absorb, out_absorb属性
|
||||
local_rank = torch.distributed.get_rank()
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
local_rank = dist.get_rank()
|
||||
tp = get_tensor_parallel_size()
|
||||
local_rank %= tp
|
||||
tp_heads = config.num_attention_heads // tp
|
||||
for i in range(config.num_hidden_layers):
|
||||
self = model.model.layers[i].self_attn
|
||||
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
|
||||
kv_b_proj = self.kv_b_proj.weight.view(config.num_attention_heads, -1, self.kv_lora_rank)
|
||||
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].clone()
|
||||
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].clone()
|
||||
q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
|
||||
out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
|
||||
out_absorb = out_absorb.transpose(1, 2).contiguous()
|
||||
setattr(self, "q_absorb", q_absorb)
|
||||
setattr(self, "out_absorb", out_absorb)
|
||||
del self.orig_module.kv_b_proj
|
||||
torch.distributed.barrier(get_tensor_parallel_group())
|
||||
attn = model.model.layers[i].self_attn
|
||||
if hasattr(attn, "q_absorb") and hasattr(attn, "out_absorb"):
|
||||
continue
|
||||
if not (hasattr(attn, "kv_b_proj")
|
||||
and hasattr(attn, "kv_lora_rank")
|
||||
and hasattr(attn, "qk_nope_head_dim")):
|
||||
continue
|
||||
|
||||
kv_b_proj = attn.kv_b_proj.weight.view(config.num_attention_heads, -1, attn.kv_lora_rank)
|
||||
q_absorb = kv_b_proj[:, :attn.qk_nope_head_dim, :].clone()
|
||||
out_absorb = kv_b_proj[:, attn.qk_nope_head_dim:, :].clone()
|
||||
|
||||
q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
|
||||
out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
|
||||
out_absorb = out_absorb.transpose(1, 2).contiguous()
|
||||
|
||||
setattr(attn, "q_absorb", q_absorb)
|
||||
setattr(attn, "out_absorb", out_absorb)
|
||||
|
||||
if hasattr(attn, "orig_module") and hasattr(attn.orig_module, "kv_b_proj"):
|
||||
del attn.orig_module.kv_b_proj
|
||||
dist.barrier(get_tensor_parallel_group())
|
||||
|
||||
|
||||
def allredeuce_warpper(func):
|
||||
|
|
|
|||
|
|
@ -666,6 +666,33 @@ def translate_name_to_gguf(name):
|
|||
|
||||
name = translate_name_to_gguf_mixtral(name)
|
||||
|
||||
if ".ffn_gate_exp." in name:
|
||||
name = name.replace(".ffn_gate_exp.", ".ffn_gate_exps.")
|
||||
if ".ffn_up_exp." in name:
|
||||
name = name.replace(".ffn_up_exp.", ".ffn_up_exps.")
|
||||
if ".ffn_down_exp." in name:
|
||||
name = name.replace(".ffn_down_exp.", ".ffn_down_exps.")
|
||||
|
||||
m = re.match(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)", name)
|
||||
if m:
|
||||
layer, expert, proj = m.groups()
|
||||
if proj == "gate_proj":
|
||||
return f"blk.{layer}.{expert}.ffn_gate_exps"
|
||||
elif proj == "up_proj":
|
||||
return f"blk.{layer}.{expert}.ffn_up_exps"
|
||||
else:
|
||||
return f"blk.{layer}.{expert}.ffn_down_exps"
|
||||
|
||||
m = re.match(r"blk\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)", name)
|
||||
if m:
|
||||
layer, expert, proj = m.groups()
|
||||
if proj == "gate_proj":
|
||||
return f"blk.{layer}.{expert}.ffn_gate_exps"
|
||||
elif proj == "up_proj":
|
||||
return f"blk.{layer}.{expert}.ffn_up_exps"
|
||||
else:
|
||||
return f"blk.{layer}.{expert}.ffn_down_exps"
|
||||
|
||||
name = name.replace("lm_head.", "output.")
|
||||
name = name.replace("model.embed_tokens.", "token_embd.")
|
||||
name = name.replace("model.norm.", "output_norm.")
|
||||
|
|
|
|||
|
|
@ -104,7 +104,11 @@ class SafeTensorLoader(ModelLoader):
|
|||
f = self.file_handle_map.get(file)
|
||||
if f is None:
|
||||
raise FileNotFoundError(f"File {file} not found in Safetensor files")
|
||||
tensor = f.get_tensor(key)
|
||||
if use_torch_npu:
|
||||
tensor = f.get_tensor(key).to(torch.float16)
|
||||
else:
|
||||
tensor = f.get_tensor(key)
|
||||
|
||||
return tensor.to(device)
|
||||
|
||||
def load_experts(self, key: str, device: str="cpu"):
|
||||
|
|
|
|||
235
archive/merge_tensors/merge_safetensor_gguf_for_qwen3.py
Normal file
235
archive/merge_tensors/merge_safetensor_gguf_for_qwen3.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import torch
|
||||
from ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
def read_safetensor_keys_from_folder(folder_path) -> dict:
    """Map each tensor key to the .safetensors file that stores it.

    Recursively walks *folder_path*; when a file path is given, its parent
    directory is scanned instead. Files are visited in sorted order so that
    duplicate keys resolve deterministically (last file wins).

    Raises:
        FileNotFoundError: if the path does not exist or no .safetensors
            files are found under it.
    """
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"Safetensors dir not found: {folder_path}")
    if os.path.isfile(folder_path):
        folder_path = os.path.dirname(folder_path)

    key_to_file_map = {}
    found_safetensor = False

    for root, _dirs, filenames in os.walk(folder_path):
        for fname in sorted(filenames):
            if not fname.endswith(".safetensors"):
                continue
            found_safetensor = True
            file_path = os.path.join(root, fname)
            try:
                with safe_open(file_path, framework="pt") as handle:
                    for tensor_key in handle.keys():
                        key_to_file_map[tensor_key] = file_path
            except Exception as e:
                # Best-effort scan: report the unreadable file and continue.
                print(f"Error reading Safetensor file {file_path}: {e}")

    if not found_safetensor:
        raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    return key_to_file_map
|
||||
|
||||
|
||||
# Optional: to source certain non-MoE tensors from GGUF as well, add their key substrings to the list below
|
||||
tensor_from_gguf = [] # e.g. ["self_attn.q_proj.weight"]
|
||||
|
||||
|
||||
def translate_name(name: str) -> str:
    """Translate an HF-style tensor name into its GGUF counterpart.

    Applies the generic ``translate_name_to_gguf`` mapping first, then the
    MoE expert-projection renames specific to this merge script.
    """
    _renames = (
        (".up_proj.", ".ffn_up_exps."),
        (".down_proj.", ".ffn_down_exps."),
        (".gate_proj.", ".ffn_gate_exps."),
        (".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias"),
    )
    translated = translate_name_to_gguf(name)
    for old, new in _renames:
        translated = translated.replace(old, new)
    return translated
|
||||
|
||||
|
||||
def combine_tensor_sources(safetensor_path: str, gguf_path: str):
    """Decide, per tensor key, whether to source it from safetensors or GGUF.

    MoE expert weights (``...mlp.experts.<idx>.<proj>.weight``) are collapsed
    to one per-layer key and sourced from the GGUF file; any key matching a
    substring in ``tensor_from_gguf`` is also taken from GGUF; everything
    else stays with its safetensors file.

    Returns:
        tuple: ``(target_tensor_map, gguf_loader)`` — output key -> source
        file path, plus the loader used later to read GGUF tensors.

    Raises:
        ValueError: on an unexpectedly short MoE expert key.
        KeyError: when a required tensor is missing from the GGUF files.
    """
    gguf_loader = GGUFLoader(gguf_path)
    gguf_tensor_file_map = gguf_loader.tensor_file_map
    st_key_map = read_safetensor_keys_from_folder(safetensor_path)

    target_tensor_map = {}

    for key, st_file in st_key_map.items():
        is_expert_weight = ".mlp.experts." in key and key.endswith(".weight")
        if is_expert_weight:
            parts = key.split(".")
            if len(parts) < 8:
                raise ValueError(f"Unexpected MoE expert key format: {key}")
            # Collapse "...experts.<idx>.<proj>.weight" to a per-layer key:
            # keep the layer prefix and the trailing "<proj>.weight".
            norm_key = ".".join(parts[:5] + parts[-2:])

            gguf_name = translate_name(norm_key)
            if gguf_name not in gguf_tensor_file_map:
                raise KeyError(
                    f"[MoE] GGUF tensor not found for safetensors key {key} -> {gguf_name}"
                )
            target_tensor_map[norm_key] = gguf_tensor_file_map[gguf_name]
        elif any(tag in key for tag in tensor_from_gguf):
            gguf_name = translate_name(key)
            if gguf_name not in gguf_tensor_file_map:
                raise KeyError(
                    f"[Non-MoE] GGUF tensor not found for safetensors key {key} -> {gguf_name}"
                )
            target_tensor_map[key] = gguf_tensor_file_map[gguf_name]
        else:
            target_tensor_map[key] = st_file

    return target_tensor_map, gguf_loader
|
||||
|
||||
|
||||
def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
    """Write the merged tensors to sharded .safetensors files under *output_path*.

    Tensors are grouped into one shard per transformer layer plus one leading
    shard for all non-layer tensors (embeddings, norms, output head, ...).
    GGUF-sourced tensors are stored undequantized together with a companion
    ``<name>.ggml_type`` scalar recording their GGML quantization type.

    Raises:
        ValueError: when *target_tensor_map* is empty, or a source file has
            an unsupported extension.
    """
    os.makedirs(output_path, exist_ok=True)

    # Cache open safetensors handles so each source file is opened once.
    safetensors_cache = {}

    # Group keys by transformer layer; everything else goes into one shard.
    layer_groups = defaultdict(list)
    non_layer_keys = []
    layer_pattern = re.compile(r"\.layers\.(\d+)\.")

    for key in target_tensor_map:
        m = layer_pattern.search(key)
        if m:
            layer_groups[int(m.group(1))].append(key)
        else:
            non_layer_keys.append(key)

    # NOTE(review): this is (number of shards - 1), so filenames read
    # "model-00000-of-{count-1}" rather than the usual 1-based HF scheme;
    # kept as-is because downstream loading may depend on these names.
    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1
    if total_shards <= 0:
        raise ValueError("No tensors to save")

    shard_idx = 0

    if non_layer_keys:
        tensors = _collect_shard_tensors(
            non_layer_keys, target_tensor_map, gguf_loader, safetensors_cache
        )
        output_file = os.path.join(
            output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors"
        )
        print(f"[WRITE] Saving non-layer tensors to {output_file}")
        save_file(tensors, output_file)
        shard_idx += 1

    for layer_num in sorted(layer_groups.keys()):
        tensors = _collect_shard_tensors(
            layer_groups[layer_num], target_tensor_map, gguf_loader, safetensors_cache
        )
        output_file = os.path.join(
            output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors"
        )
        print(f"[WRITE] Saving layer {layer_num} to {output_file}")
        save_file(tensors, output_file)
        shard_idx += 1


def _collect_shard_tensors(keys, target_tensor_map, gguf_loader, safetensors_cache):
    """Load every tensor in *keys* and return a dict ready for ``save_file``.

    Safetensors sources are read via handles cached in *safetensors_cache*;
    GGUF sources stay undequantized and receive an extra ``.ggml_type``
    entry so a later loader can dequantize them.
    """
    tensors = {}
    for key in keys:
        file_path = target_tensor_map[key]
        ggml_type = None

        if file_path.endswith(".safetensors"):
            if file_path not in safetensors_cache:
                safetensors_cache[file_path] = safe_open(file_path, framework="pt")
            f = safetensors_cache[file_path]
            tensor = f.get_tensor(key)
        elif file_path.endswith(".gguf"):
            gguf_name = translate_name(key)
            tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
        else:
            raise ValueError(f"Unsupported file format: {file_path}")

        out_key = translate_name(key)
        tensors[out_key] = tensor
        if ggml_type is not None:
            ggml_type = torch.tensor(ggml_type)
            # "<base>.ggml_type" replaces a trailing ".weight" when present.
            if out_key.endswith(".weight"):
                ggml_key = out_key[:-7] + ".ggml_type"
            else:
                ggml_key = out_key + ".ggml_type"
            tensors[ggml_key] = ggml_type
    return tensors
|
||||
|
||||
|
||||
def main():
    """CLI entry point: merge FP8 safetensors with GGUF tensors for Qwen3-30B-A3B."""
    parser = argparse.ArgumentParser(
        description="Merge FP8 safetensors and GGUF tensors for Qwen3-30B-A3B"
    )
    parser.add_argument(
        "--safetensor_path",
        type=str,
        help="Path to the FP8 Safetensor folder",
        default="/mnt/data/model/Qwen3-30B-A3B-FP8",
    )
    parser.add_argument(
        "--gguf_path",
        type=str,
        help="Path to the GGUF file or folder",
        default="/mnt/data/model/Qwen3-30B-A3B-GGUF",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path to the output safetensors folder",
        default="/mnt/data/model/ktrans-safetensors/Qwen3-30B-A3B-q4km-fp8",
    )

    args = parser.parse_args()
    print("[ARGS]", args)

    # Resolve each tensor's source file, then write the merged shards.
    target_tensor_map, gguf_loader = combine_tensor_sources(args.safetensor_path, args.gguf_path)
    write_combined_tensor(target_tensor_map, args.output_path, gguf_loader)


if __name__ == "__main__":
    main()
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue