Mirror of https://github.com/kvcache-ai/ktransformers.git
Commit 3f9bbf1181: "support qwen3, dont speak human language"
Parent: f3d842a0ca
30 changed files with 3696 additions and 290 deletions

The hunks below are an excerpt from one of the changed files: the balance_serve Engine that builds the model, KV cache, and model runner.
@@ -1,5 +1,5 @@
 from typing import Any, AsyncIterator, List, Optional, Set
-from ktransformers.models.custom_cache import KDeepSeekV3Cache
+from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache
 from transformers import (
     AutoTokenizer,
     AutoConfig,
@@ -22,6 +22,9 @@ from ktransformers.server.config.log import logger
 from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
 from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
+from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
+from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
 from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
 from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
 from ktransformers.server.balance_serve.inference.query_manager import QueryManager
@@ -53,8 +56,10 @@ ktransformer_rules_dir = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
 )
 default_optimize_rules = {
-    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
-    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
+    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "Moonlight-16B-A3B-serve.yaml",
+    # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
+    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
+    "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
 }
 
 
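The defaults above are keyed by the architecture string reported in the model config, so selecting an optimize-rule YAML is presumably a plain dict lookup. A minimal sketch, with a hypothetical helper name (not ktransformers API):

import os

def pick_optimize_rule(architecture: str, rules: dict) -> str:
    # architecture would be config.architectures[0], e.g. "Qwen3MoeForCausalLM",
    # which the new entry above maps to Qwen3Moe-serve.yaml.
    if architecture not in rules:
        raise KeyError(f"no optimize rule registered for {architecture}")
    return os.path.normpath(rules[architecture])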
@@ -105,7 +110,7 @@ class Engine:
     model_runner: ModelRunner
     sampler: Sampler
     query_manager: QueryManager
-    cache: KDeepSeekV3Cache
+    cache: KDeepSeekV3Cache | KGQACache
     def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
         self.args = args
 
@@ -117,17 +122,32 @@ class Engine:
         self.device = self.args.device
         self.sched_client = SchedulerClient(args.sched_port)
         self.updates = []
-        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        self.cache = KDeepSeekV3Cache(config, self.args.page_size)
+
+        try:
+            config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        except:
+            if args.model_name == "Qwen3Moe":
+                config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+            else:
+                assert False, f"model {args.model_name} not supported"
+
 
         self.gen_queue = generated_token_queue
 
         with torch.device("meta"):
             if config.architectures[0] == "DeepseekV3ForCausalLM":
+                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                 self.model = KDeepseekV3ForCausalLM(config, self.cache)
             elif config.architectures[0] == "DeepseekV2ForCausalLM":
+                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                 self.model = KDeepseekV2ForCausalLM(config, self.cache)
             # print(self.block_num)
+            elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
+                self.cache = KGQACache(config, self.args.page_size)
+                if config.architectures[0] == "Qwen2MoeForCausalLM":
+                    self.model = KQwen2MoeForCausalLM(config, self.cache)
+                else:
+                    self.model = KQwen3MoeForCausalLM(config, self.cache)
 
 
         context = zmq.Context()
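The core of the change is this dispatch: Qwen2/Qwen3 MoE models get the new KGQACache, the DeepSeek models keep KDeepSeekV3Cache, and each architecture is paired with its wrapper model on the meta device. A minimal standalone sketch of the same mapping; the build_model_and_cache helper is illustrative, not ktransformers API:

import torch
from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM

# architecture string -> (paged KV-cache class, wrapper model class)
_ARCH_DISPATCH = {
    "DeepseekV3ForCausalLM": (KDeepSeekV3Cache, KDeepseekV3ForCausalLM),
    "DeepseekV2ForCausalLM": (KDeepSeekV3Cache, KDeepseekV2ForCausalLM),
    "Qwen2MoeForCausalLM": (KGQACache, KQwen2MoeForCausalLM),
    "Qwen3MoeForCausalLM": (KGQACache, KQwen3MoeForCausalLM),
}

def build_model_and_cache(config, page_size):
    # config.architectures[0] is the same string the __init__ above switches on
    cache_cls, model_cls = _ARCH_DISPATCH[config.architectures[0]]
    cache = cache_cls(config, page_size)
    with torch.device("meta"):  # weights are materialized later by optimize_and_load_gguf
        model = model_cls(config, cache)
    return model, cache

Note also that the hunk's bare except: catches any AutoConfig failure and falls back to Qwen3MoeConfig only when args.model_name is "Qwen3Moe"; every other failure hits the assert.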
@@ -176,9 +196,12 @@ class Engine:
 
         self.block_num = inference_context.k_cache[0].size(1)
         #@TODO add config
-        self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
+        if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
+            self.model.init_wrapper(self.args.use_cuda_graph, self.device, 1024 ,args.max_batch_size, self.block_num) # TODO: 1024 is a magic number(max_batch_tokens)
+        else:
+            self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
 
-        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size)
+        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
         self.sampler = Sampler()
         self.query_manager = QueryManager(device = self.device, page_size = args.page_size)
 
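As the inline TODO notes, the GQA path passes a hard-coded max_batch_tokens of 1024 as an extra argument to init_wrapper, while the DeepSeek path keeps the old four-argument call. A hedged sketch of how that call-site split could be factored out, with the magic number lifted to a parameter; the helper and the max_batch_tokens default are assumptions, not existing ktransformers code:

def init_attention_wrapper(model, config, args, device, block_num, max_batch_tokens=1024):
    # Sketch only: 1024 is the magic number the TODO refers to; making it a
    # parameter (or a ConfigArgs field) would remove it from the call site.
    arch = config.architectures[0]
    if arch in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        # GQA models take the extra max_batch_tokens argument
        model.init_wrapper(args.use_cuda_graph, device, max_batch_tokens, args.max_batch_size, block_num)
    else:
        model.init_wrapper(args.use_cuda_graph, device, args.max_batch_size, block_num)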
@@ -231,7 +254,7 @@ class Engine:
 
             if self.batch is not None:
                 self.model_runner.sync()
-                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
+                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms, {1000/self.model_runner.model_time:.3f} tokens/s")
                 # if self.rank == 0:
 
                 generated_tokens, probs = self.sampling( self.model_runner.output)
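The extended log line turns the per-step GPU time into a throughput figure: model_time is in milliseconds, so 1000 / model_time is the number of decode steps per second. Since each step emits one token per sequence, the printed value is per-sequence decode throughput rather than aggregate batch throughput. For example:

model_time_ms = 12.5                 # one decode step took 12.5 ms on the GPU
tokens_per_s = 1000 / model_time_ms  # 80.0 decode steps (tokens per sequence) per second
print(f"Model execution time (GPU): {model_time_ms:.3f} ms, {tokens_per_s:.3f} tokens/s")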