Mirror of https://github.com/kvcache-ai/ktransformers.git
update: Qwen3 MoE model adaptation for NPU (framework) (#1706)
parent 53f6a6d6e1
commit adcfa9080f
10 changed files with 867 additions and 174 deletions

@@ -1,5 +1,5 @@
 from typing import Any, AsyncIterator, List, Optional, Set
-from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache
+from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache, KVC2Qwen3Cache
 from transformers import (
     AutoTokenizer,
     AutoConfig,

@@ -39,6 +39,7 @@ except:
     use_torch_npu = False
 if use_torch_npu:
     from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
+    from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM
     from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size

 from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM

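For orientation: `use_torch_npu` is set by a try/except probe earlier in this file (the `except:` visible in the hunk header above), and the NPU model classes are imported only when that probe succeeds. A minimal sketch of the pattern; everything except the `use_torch_npu` name and the ascend import line is an assumption rather than the file's actual code:

try:
    import torch_npu  # assumption: the probe imports the Ascend backend package
    use_torch_npu = True
except ImportError:
    use_torch_npu = False

if use_torch_npu:
    # NPU-only classes live under ktransformers.models.ascend, so they are
    # only resolved when the backend is actually present.
    from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM
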
@@ -50,7 +51,7 @@ custom_models = {
     "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "MixtralForCausalLM": MixtralForCausalLM,
-} #TODO NPU-only?
+}
 from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner #TODO get_or_create_model_runner NPU-only?
 from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig
 from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions

@@ -198,11 +199,15 @@ class Engine:
             self.cache = KDeepSeekV3Cache(config, self.args.page_size)
             self.model = KDeepseekV2ForCausalLM(config, self.cache)
         elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
-            self.cache = KGQACache(config, self.args.page_size)
-            if config.architectures[0] == "Qwen2MoeForCausalLM":
-                self.model = KQwen2MoeForCausalLM(config, self.cache)
+            if not use_torch_npu:
+                self.cache = KGQACache(config, self.args.page_size)
+                if config.architectures[0] == "Qwen2MoeForCausalLM":
+                    self.model = KQwen2MoeForCausalLM(config, self.cache)
+                else:
+                    self.model = KQwen3MoeForCausalLM(config, self.cache)
             else:
-                self.model = KQwen3MoeForCausalLM(config, self.cache)
+                self.cache = KVC2Qwen3Cache(config, args.max_batch_size, self.args.page_size)
+                self.model = KNPUQwen3MoeForCausalLM(config, self.cache)
         elif config.architectures[0] == "SmallThinkerForCausalLM":
             self.cache = KGQACache(config, self.args.page_size)
             self.model = KSmallThinkerForCausalLM(config, self.cache)

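Read as a whole, this hunk splits the Qwen2/Qwen3 MoE branch on the backend: the GPU path keeps the existing `KGQACache` plus `KQwen2MoeForCausalLM`/`KQwen3MoeForCausalLM`, while the NPU path builds a `KVC2Qwen3Cache` (which, unlike `KGQACache`, also takes `args.max_batch_size`) and hands it to `KNPUQwen3MoeForCausalLM`. A sketch of the resulting branch with the diff markers stripped; the architecture check is condensed into a tuple test and the indentation is reconstructed, so treat it as illustrative:

elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
    if not use_torch_npu:
        # GPU path: unchanged cache and model classes
        self.cache = KGQACache(config, self.args.page_size)
        if config.architectures[0] == "Qwen2MoeForCausalLM":
            self.model = KQwen2MoeForCausalLM(config, self.cache)
        else:
            self.model = KQwen3MoeForCausalLM(config, self.cache)
    else:
        # NPU path: the KVC2-backed cache needs the maximum batch size up front
        self.cache = KVC2Qwen3Cache(config, args.max_batch_size, self.args.page_size)
        self.model = KNPUQwen3MoeForCausalLM(config, self.cache)
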
@@ -277,7 +282,11 @@ class Engine:
         # self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
         #@TODO add config
         if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM" or config.architectures[0] == "Qwen3NextForCausalLM":
-            self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
+            if not use_torch_npu:
+                self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
+            else:
+                # NPU does not support flash attn
+                self.model.init_wrapper()
         else:
             self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

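The wrapper setup follows the same split: on CUDA the model's attention wrapper is initialized with the CUDA-graph flag, the device, the largest captured graph size, the maximum batch size, and the KV block count, while on the NPU `init_wrapper()` is called with no arguments because flash attention is unavailable there. A condensed sketch of the two call shapes; the argument roles in the comments are inferred from the names in the diff, so treat them as assumptions:

if not use_torch_npu:
    # CUDA path: wrapper sized from the largest captured CUDA graph,
    # the maximum batch size, and the number of KV-cache blocks.
    self.model.init_wrapper(
        self.args.use_cuda_graph,
        self.device,
        max(self.model_runner.cuda_graphs),
        args.max_batch_size,
        self.block_num,
    )
else:
    # NPU path: nothing to configure, flash attention is not supported.
    self.model.init_wrapper()
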
@@ -322,7 +331,7 @@ class Engine:
                 batch_size = 0
                 for i in range(len(self.batch.decode_mini_batches)):
                     batch_size += len(self.batch.decode_mini_batches[i])
-                logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
+                # logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
                 self.model_runner.run_split(self.batch, self.query_manager)
             else:
                 self.model_runner.run(self.batch, self.query_manager)

@@ -403,9 +412,12 @@ def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank
     engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
     if args.use_cuda_graph:
         if 'npu' in engine.device:
+            print(f"[WARMUP-NPU] start", flush=True)
             engine.model_runner.warmup_npu()
         else:
             engine.model_runner.warmup()
+    else:
+        print(f"[WARMUP-NPU] skip warmup, eager mode!", flush=True)
     if use_torch_npu:
         args.port += torch.distributed.get_rank()
     event.set()

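Warm-up is dispatched on the device string: with CUDA graphs enabled, the runner warms up via `warmup_npu()` when the engine device is an NPU and via `warmup()` otherwise, and eager mode skips warm-up altogether; the new prints make both outcomes visible in the logs. On the NPU side each distributed rank also offsets its server port by its rank. A sketch of the resulting control flow, using only names that appear in the hunk:

if args.use_cuda_graph:
    if 'npu' in engine.device:
        print("[WARMUP-NPU] start", flush=True)
        engine.model_runner.warmup_npu()   # NPU warm-up path
    else:
        engine.model_runner.warmup()       # existing CUDA warm-up path
else:
    print("[WARMUP-NPU] skip warmup, eager mode!", flush=True)

if use_torch_npu:
    # each distributed rank listens on its own port
    args.port += torch.distributed.get_rank()
event.set()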