update: Qwen3 MoE model adaptation for NPU (framework) (#1706)

Shaoxu Cheng 2025-12-11 17:07:57 +08:00 committed by GitHub
parent 53f6a6d6e1
commit adcfa9080f
10 changed files with 867 additions and 174 deletions


@@ -1,5 +1,5 @@
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache
from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache, KVC2Qwen3Cache
from transformers import (
AutoTokenizer,
AutoConfig,
@@ -39,6 +39,7 @@ except:
use_torch_npu = False
if use_torch_npu:
from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
@@ -50,7 +51,7 @@ custom_models = {
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
} #TODO exclusive?
}
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner #TODO get_or_create_model_runner is NPU-specific
from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
@@ -198,11 +199,15 @@ class Engine:
self.cache = KDeepSeekV3Cache(config, self.args.page_size)
self.model = KDeepseekV2ForCausalLM(config, self.cache)
elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
if config.architectures[0] == "Qwen2MoeForCausalLM":
self.model = KQwen2MoeForCausalLM(config, self.cache)
if not use_torch_npu:
self.cache = KGQACache(config, self.args.page_size)
if config.architectures[0] == "Qwen2MoeForCausalLM":
self.model = KQwen2MoeForCausalLM(config, self.cache)
else:
self.model = KQwen3MoeForCausalLM(config, self.cache)
else:
self.model = KQwen3MoeForCausalLM(config, self.cache)
self.cache = KVC2Qwen3Cache(config, args.max_batch_size, self.args.page_size)
self.model = KNPUQwen3MoeForCausalLM(config, self.cache)
elif config.architectures[0] == "SmallThinkerForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
self.model = KSmallThinkerForCausalLM(config, self.cache)
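In effect, the hunk above makes the Qwen2/Qwen3 MoE branch dispatch on use_torch_npu: the GPU path keeps KGQACache with KQwen2MoeForCausalLM / KQwen3MoeForCausalLM, while the NPU path pairs the new KVC2Qwen3Cache with the Ascend-specific KNPUQwen3MoeForCausalLM. A minimal sketch of that dispatch as a standalone helper (select_qwen_moe_backend is a hypothetical name, not part of this PR):

# Sketch only: mirrors the branching added above; the class names come from
# the diff, the helper itself is hypothetical.
def select_qwen_moe_backend(config, args, use_torch_npu):
    arch = config.architectures[0]
    if not use_torch_npu:
        # GPU path: paged GQA cache plus the existing CUDA model classes
        cache = KGQACache(config, args.page_size)
        if arch == "Qwen2MoeForCausalLM":
            model = KQwen2MoeForCausalLM(config, cache)
        else:  # "Qwen3MoeForCausalLM"
            model = KQwen3MoeForCausalLM(config, cache)
    else:
        # NPU path: KVC2-backed cache plus the Ascend model class
        cache = KVC2Qwen3Cache(config, args.max_batch_size, args.page_size)
        model = KNPUQwen3MoeForCausalLM(config, cache)
    return cache, model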
@@ -277,7 +282,11 @@ class Engine:
# self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
#@TODO add config
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM" or config.architectures[0] == "Qwen3NextForCausalLM":
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
if not use_torch_npu:
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
else:
# NPU does not support flash attention
self.model.init_wrapper()
else:
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
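The wrapper initialization follows the same split: the CUDA path passes the graph settings into init_wrapper, while the NPU path calls it with no arguments because flash attention and CUDA graphs are unavailable there. One way a model class can accept both call shapes is defaulted parameters; a hypothetical sketch, not the signature used by these classes:

# Hypothetical signature, for illustration only: defaults let the NPU model
# be initialized with a bare init_wrapper() while CUDA models receive the
# full graph configuration.
def init_wrapper(self, use_cuda_graph=False, device=None,
                 max_graph_tokens=0, max_batch_size=1, block_num=0):
    if not use_cuda_graph:
        return  # NPU / eager mode: nothing to capture
    # CUDA path: graph capture and wrapper setup would happen here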
@@ -322,7 +331,7 @@ class Engine:
batch_size = 0
for i in range(len(self.batch.decode_mini_batches)):
batch_size += len(self.batch.decode_mini_batches[i])
logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
# logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
self.model_runner.run_split(self.batch, self.query_manager)
else:
self.model_runner.run(self.batch, self.query_manager)
@@ -403,9 +412,12 @@ def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank
engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
if args.use_cuda_graph:
if 'npu' in engine.device:
print(f"[WARMUP-NPU] start", flush=True)
engine.model_runner.warmup_npu()
else:
engine.model_runner.warmup()
else:
print(f"[WARMUP-NPU] skip warmup, eager mode!", flush=True)
if use_torch_npu:
args.port += torch.distributed.get_rank()
event.set()
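Finally, under torch_npu each rank offsets its serving port by its distributed rank so that tensor-parallel processes do not collide on the same port. A small illustration, assuming the process group has already been initialized (e.g. by setup_model_parallel) and an illustrative base port:

# Sketch: rank-based port offsetting as done above; base_port is illustrative.
import torch.distributed as dist
base_port = 10002                    # hypothetical value of args.port
port = base_port + dist.get_rank()   # rank 0 -> 10002, rank 1 -> 10003, ...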