mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
support smt and glm4
This commit is contained in:
parent
1677e90092
commit
b66d96db97
18 changed files with 3519 additions and 16 deletions
|
@ -24,7 +24,11 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
|
|||
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
|
||||
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
|
||||
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
|
||||
from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM
|
||||
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
|
||||
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
|
||||
from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
|
||||
from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
|
||||
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
|
||||
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
|
||||
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
|
||||
|
@ -60,6 +64,8 @@ default_optimize_rules = {
|
|||
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
|
||||
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
|
||||
"Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
|
||||
"SmallthinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
|
||||
"Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
|
||||
}
|
||||
|
||||
|
||||
|
@ -123,15 +129,24 @@ class Engine:
|
|||
self.sched_client = SchedulerClient(args.sched_port)
|
||||
self.updates = []
|
||||
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
except:
|
||||
if args.model_name == "Qwen3Moe":
|
||||
config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
else:
|
||||
assert False, f"model {args.model_name} not supported"
|
||||
print(f"args.model_name: {args.model_name}")
|
||||
|
||||
if args.model_name == "Qwen3MoeForCausalLM":
|
||||
config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
elif args.model_name == "Glm4MoeForCausalLM":
|
||||
config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
elif args.model_name == "SmallthinkerForCausalLM":
|
||||
config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
config._attn_implementation = "eager"
|
||||
else:
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
except:
|
||||
raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
self.gen_queue = generated_token_queue
|
||||
|
||||
with torch.device("meta"):
|
||||
|
@ -147,6 +162,13 @@ class Engine:
|
|||
self.model = KQwen2MoeForCausalLM(config, self.cache)
|
||||
else:
|
||||
self.model = KQwen3MoeForCausalLM(config, self.cache)
|
||||
elif config.architectures[0] == "SmallthinkerForCausalLM":
|
||||
self.cache = KGQACache(config, self.args.page_size)
|
||||
self.model = KSmallthinkerForCausalLM(config, self.cache)
|
||||
elif config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
self.cache = KGQACache(config, self.args.page_size)
|
||||
self.model = KGlm4MoeForCausalLM(config, self.cache)
|
||||
|
||||
|
||||
|
||||
context = zmq.Context()
|
||||
|
@ -197,7 +219,7 @@ class Engine:
|
|||
self.block_num = inference_context.k_cache[0].size(1)
|
||||
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
|
||||
#@TODO add config
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallthinkerForCausalLM":
|
||||
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
|
||||
else:
|
||||
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue