Mirror of https://github.com/kvcache-ai/ktransformers.git

Commit 17246bf84f ("support smt and glm4"), parent 48bc6185b5
7 changed files with 129 additions and 16 deletions
```diff
@@ -15,16 +15,14 @@ using ModelName = std::string;
 class ModelConfig {
 public:
   DimSize hidden_size;
-  DimSize intermediate_size;
   size_t max_position_embeddings;
-  std::string model_type;
   size_t num_attention_heads;
   size_t num_hidden_layers;
   size_t num_key_value_heads;
   size_t vocab_size;

-  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
-                                 max_position_embeddings, model_type,
+  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size,
+                                 max_position_embeddings,
                                  num_attention_heads, num_hidden_layers,
                                  num_key_value_heads, vocab_size);
```
doc/en/SmallThinker_and_Glm4moe.md (new file, 114 lines):
# SmallThinker & GLM-4-MoE Support for KTransformers

## Introduction

### Overview

We are excited to announce that **KTransformers now supports both SmallThinker and GLM-4-MoE**.

- **SmallThinker-21B (bf16)**: ~26 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~42 GB DRAM.
- **GLM-4-MoE 110B (bf16)**: ~11 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~220 GB DRAM.

### Model & Resource Links

- **SmallThinker-21B**
  - *(to be announced)*
- **GLM-4-MoE 110B**
  - *(to be announced)*

---
## Installation Guide

### 1. Resource Requirements

| Model            | Precision | Experts | DRAM Needed | GPU Memory Needed\* | TPS (approx.) |
|------------------|-----------|---------|-------------|---------------------|---------------|
| SmallThinker-21B | bf16      | 32      | ~42 GB      | 14 GB               | ~26 TPS       |
| GLM-4-MoE 110B   | bf16      | 128     | ~220 GB     | 14 GB               | ~11 TPS       |

\* Exact GPU memory depends on sequence length, batch size, and kernels used.
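The DRAM figures track the raw bf16 weight size almost one-to-one. A quick sanity check, as a sketch assuming weights dominate DRAM use and bf16 stores 2 bytes per parameter:

```python
# Rough DRAM estimate for bf16 weights: params x 2 bytes.
# Assumption: activations and KV cache live mostly on the GPU,
# so host DRAM is dominated by the weights themselves.
def bf16_weight_gb(params_billion: float) -> float:
    return params_billion * 1e9 * 2 / 1e9  # bytes -> GB

print(bf16_weight_gb(21))   # ~42 GB, matches SmallThinker-21B
print(bf16_weight_gb(110))  # ~220 GB, matches GLM-4-MoE 110B
```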
### 2. Prepare Models

```bash
# Example: download original safetensors (adjust to your paths/repos)
# (Fill in actual repos/filenames yourself)

# SmallThinker-21B
huggingface-cli download --resume-download placeholder-org/Model-TBA \
  --local-dir ./Model-TBA

# GLM-4-MoE 110B
huggingface-cli download --resume-download placeholder-org/Model-TBA \
  --local-dir ./Model-TBA
```

### 3. Install KTransformers

Follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).

```bash
pip install ktransformers  # or from source if you need bleeding-edge features
```

### 4. Run SmallThinker-21B Inference Server

```bash
python ktransformers/server/main.py \
  --port 10021 \
  --model_path /abs/path/to/SmallThinker-21B-bf16 \
  --model_name SmallThinkerForCausalLM \
  --optimize_config_path ktransformers/optimize/optimize_rules/SmallThinker-serve.yaml \
  --max_new_tokens 1024 \
  --cache_lens 32768 \
  --chunk_size 256 \
  --max_batch_size 4 \
  --backend_type balance_serve
```

### 5. Run GLM-4-MoE 110B Inference Server

```bash
python ktransformers/server/main.py \
  --port 10110 \
  --model_name Glm4MoeForCausalLM \
  --model_path /abs/path/to/GLM-4-MoE-110B-bf16 \
  --optimize_config_path ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml \
  --max_new_tokens 1024 \
  --cache_lens 32768 \
  --chunk_size 256 \
  --max_batch_size 4 \
  --backend_type balance_serve
```

### 6. Access Server

```bash
curl -X POST http://localhost:10021/v1/chat/completions \
  -H "accept: application/json" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "hello"}
    ],
    "model": "SmallThinker-21B",
    "temperature": 0.3,
    "top_p": 1.0,
    "stream": true
  }'
```

```bash
curl -X POST http://localhost:10110/v1/chat/completions \
  -H "accept: application/json" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "hello"}
    ],
    "model": "GLM-4-MoE-110B",
    "temperature": 0.3,
    "top_p": 1.0,
    "stream": true
  }'
```
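Both servers speak an OpenAI-compatible chat completions API. For illustration, here is a minimal Python streaming client sketch, assuming the endpoints emit standard OpenAI-style `data: ...` server-sent events when `stream` is true:

```python
# Minimal streaming client for the endpoints above.
# Assumption: the server streams standard OpenAI-style SSE chunks
# terminated by a `data: [DONE]` sentinel.
import json
import requests

resp = requests.post(
    "http://localhost:10021/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "hello"}],
        "model": "SmallThinker-21B",
        "temperature": 0.3,
        "top_p": 1.0,
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    delta = json.loads(payload)["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)
```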
```diff
@@ -24,7 +24,7 @@ model:
   type: balance_serve
   # type: ktransformers

-  name: SmallthinkerForCausalLM
+  name: SmallThinkerForCausalLM
   path: /mnt/data/models/Smallthinker-21B
   gguf_path: /mnt/data/models/Smallthinker-21B
```
```diff
@@ -24,7 +24,7 @@ torch.set_grad_enabled(False)
 torch.set_default_dtype(torch.bfloat16)
 import flashinfer

-class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
+class KSmallThinkerForCausalLM(SmallthinkerPreTrainedModel):

     cache: KGQACache
     use_cuda_graph = False
```
```diff
@@ -143,7 +143,7 @@ class ArgumentParser:
             model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
         elif args.model_name == "Glm4MoeForCausalLM":
             model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        elif args.model_name == "SmallthinkerForCausalLM":
+        elif args.model_name == "SmallThinkerForCausalLM":
             model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
             model_config._attn_implementation = "eager"
         else:
```
```diff
@@ -153,7 +153,7 @@ class ArgumentParser:
             raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")

-        if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallthinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM":
+        if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallThinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM":
            args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim
            args.architectures = model_config.architectures[0]
        else:
```
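The `gpu_memory_size` expression above is the KV-cache budget: `cache_lens` token slots, times 2 tensors (K and V), times 2 bytes per bf16 element, across all layers, KV heads, and head dimensions. A worked sketch with illustrative numbers (the layer/head values below are placeholders, not taken from either model's config):

```python
# KV-cache sizing as computed by ArgumentParser (sketch).
cache_lens = 32768          # matches --cache_lens in the serve commands
num_hidden_layers = 46      # assumption: example value only
num_key_value_heads = 8     # assumption: example value only
head_dim = 128              # assumption: example value only

# 2 tensors (K and V) x 2 bytes (bf16) per element
gpu_memory_size = cache_lens * 2 * 2 * num_hidden_layers * num_key_value_heads * head_dim
print(f"{gpu_memory_size / 2**30:.1f} GiB")  # ~5.8 GiB for these values
```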
```diff
@@ -24,7 +24,7 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
 from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
 from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
 from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
-from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM
+from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM
 from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
 from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
 from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
```
```diff
@@ -138,6 +138,7 @@ class Engine:
         elif args.model_name == "SmallThinkerForCausalLM":
             config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
             config._attn_implementation = "eager"
+            config.moe_intermediate_size = config.moe_ffn_hidden_size
         else:
             try:
                 config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
```
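SmallThinker's configuration exposes its expert FFN width as `moe_ffn_hidden_size`, while the shared MoE path appears to read `moe_intermediate_size`; the added line aliases one to the other. A self-contained sketch of the same normalization (an illustration under those assumptions, not the commit's code):

```python
from types import SimpleNamespace

# Hypothetical config with a SmallThinker-style field name (illustrative value).
config = SimpleNamespace(moe_ffn_hidden_size=4096)

# Normalize to the attribute the shared MoE operators are assumed to read.
if not hasattr(config, "moe_intermediate_size"):
    config.moe_intermediate_size = config.moe_ffn_hidden_size

print(config.moe_intermediate_size)  # 4096
```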
```diff
@@ -164,7 +165,7 @@ class Engine:
             self.model = KQwen3MoeForCausalLM(config, self.cache)
         elif config.architectures[0] == "SmallThinkerForCausalLM":
             self.cache = KGQACache(config, self.args.page_size)
-            self.model = KSmallthinkerForCausalLM(config, self.cache)
+            self.model = KSmallThinkerForCausalLM(config, self.cache)
         elif config.architectures[0] == "Glm4MoeForCausalLM":
             self.cache = KGQACache(config, self.args.page_size)
             self.model = KGlm4MoeForCausalLM(config, self.cache)
```
```diff
@@ -462,8 +463,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         profiler.create_and_start_timer("prefill")

         query_add = sched_ext.QueryAdd()
-        input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444,
-        99052, 101052, 11314]], device='cuda')
+        # input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444,
+        # 99052, 101052, 11314]], device='cuda')
         query_add.query_token = input_ids[0].tolist()
         query_length = input_ids[0].shape[0]
         query_add.query_length = query_length
```
```diff
@@ -29,7 +29,7 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
 from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
 from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
 from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
-from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM
+from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM
 from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
 from ktransformers.server.balance_serve.inference.query_manager import QueryManager
 from ktransformers.server.balance_serve.settings import sched_ext
```
```diff
@@ -55,7 +55,7 @@ def generate_cuda_graphs(chunk_size: int) -> list:
 class ModelRunner:
     """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""

-    model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallthinkerForCausalLM | KGlm4MoeForCausalLM
+    model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM
     input: ForwardBatchInput | list[ForwardBatchInput]
     output: ForwardBatchOutput
```
```diff
@@ -95,7 +95,7 @@ class ModelRunner:
                 num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                 head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                 sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
-        elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallthinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
+        elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
             self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                 num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
                 head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads,
```
```diff
@@ -126,7 +126,7 @@ class ModelRunner:
             num_tokens = self.features_buf[i][0].size(0)
             print("capturing cuda graph", batch_size, num_tokens)

-            if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallthinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
+            if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
                 self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens)

             self.bsz_tensor_buf[0] = batch_size
```