diff --git a/csrc/balance_serve/sched/model_config.h b/csrc/balance_serve/sched/model_config.h index e7512c4..78fc8dc 100644 --- a/csrc/balance_serve/sched/model_config.h +++ b/csrc/balance_serve/sched/model_config.h @@ -15,16 +15,14 @@ using ModelName = std::string; class ModelConfig { public: DimSize hidden_size; - DimSize intermediate_size; size_t max_position_embeddings; - std::string model_type; size_t num_attention_heads; size_t num_hidden_layers; size_t num_key_value_heads; size_t vocab_size; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, - max_position_embeddings, model_type, + NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, + max_position_embeddings, num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size); diff --git a/doc/en/SmallThinker_and_Glm4moe.md b/doc/en/SmallThinker_and_Glm4moe.md new file mode 100644 index 0000000..20bcc2e --- /dev/null +++ b/doc/en/SmallThinker_and_Glm4moe.md @@ -0,0 +1,114 @@ +# SmallThinker & GLM-4-MoE Support for KTransformers + +## Introduction + +### Overview +We are excited to announce that **KTransformers now supports both SmallThinker and GLM-4-MoE**. + +- **SmallThinker-21B (bf16)**: ~26 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~42 GB DRAM. +- **GLM-4-MoE 110B (bf16)**: ~11 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~220 GB DRAM. + +### Model & Resource Links +- **SmallThinker-21B** + - *(to be announced)* +- **GLM-4-MoE 110B** + - *(to be announced)* + +--- + +## Installation Guide + +### 1. Resource Requirements + +| Model | Precision | Experts | DRAM Needed | GPU Memory Needed* | TPS (approx.) | +|-----------------------|-----------|---------|-------------|--------------------|---------------| +| SmallThinker-21B | bf16 | 32 | ~42 GB | 14GB | ~26 TPS | +| GLM-4-MoE 110B | bf16 | 128 | ~220 GB | 14GB | ~11 TPS | + +\* Exact GPU memory depends on sequence length, batch size, and kernels used. + +### 2. Prepare Models + +```bash +# Example: download original safetensors (adjust to your paths/repos) +# (Fill in actual repos/filenames yourself) + +# SmallThinker-21B +huggingface-cli download --resume-download placeholder-org/Model-TBA \ + --local-dir ./Model-TBA + +# GLM-4-MoE 110B +huggingface-cli download --resume-download placeholder-org/Model-TBA \ + --local-dir ./Model-TBA +``` + + +### 3. Install KTransformers + +Follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html). + +```bash +pip install ktransformers # or from source if you need bleeding-edge features +``` + +### 4. Run SmallThinker-21B Inference Server + +```bash +python ktransformers/server/main.py \ + --port 10021 \ + --model_path /abs/path/to/SmallThinker-21B-bf16 \ + --model_name SmallThinkerForCausalLM \ + --optimize_config_path ktransformers/optimize/optimize_rules/SmallThinker-serve.yaml \ + --max_new_tokens 1024 \ + --cache_lens 32768 \ + --chunk_size 256 \ + --max_batch_size 4 \ + --backend_type balance_serve +``` + +### 5. Run GLM-4-MoE 110B Inference Server + +```bash +python ktransformers/server/main.py \ + --port 10110 \ + --model_name Glm4MoeForCausalLM \ + --model_path /abs/path/to/GLM-4-MoE-110B-bf16 \ + --optimize_config_path ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml \ + --max_new_tokens 1024 \ + --cache_lens 32768 \ + --chunk_size 256 \ + --max_batch_size 4 \ + --backend_type balance_serve +``` + +### 6. 
Access Server + +```bash +curl -X POST http://localhost:10021/v1/chat/completions \ + -H "accept: application/json" \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "user", "content": "hello"} + ], + "model": "SmallThinker-21B", + "temperature": 0.3, + "top_p": 1.0, + "stream": true + }' +``` + +```bash +curl -X POST http://localhost:10110/v1/chat/completions \ + -H "accept: application/json" \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "user", "content": "hello"} + ], + "model": "GLM-4-MoE-110B", + "temperature": 0.3, + "top_p": 1.0, + "stream": true + }' +``` diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml index 6d0af03..4050c8c 100644 --- a/ktransformers/configs/config.yaml +++ b/ktransformers/configs/config.yaml @@ -24,7 +24,7 @@ model: type: balance_serve # type: ktransformers - name: SmallthinkerForCausalLM + name: SmallThinkerForCausalLM path: /mnt/data/models/Smallthinker-21B gguf_path: /mnt/data/models/Smallthinker-21B diff --git a/ktransformers/models/custom_modeling_smallthinker.py b/ktransformers/models/custom_modeling_smallthinker.py index 759de61..d27e0b0 100644 --- a/ktransformers/models/custom_modeling_smallthinker.py +++ b/ktransformers/models/custom_modeling_smallthinker.py @@ -24,7 +24,7 @@ torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) import flashinfer -class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel): +class KSmallThinkerForCausalLM(SmallthinkerPreTrainedModel): cache: KGQACache use_cuda_graph = False diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index 3b8f58c..6cccac8 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -143,7 +143,7 @@ class ArgumentParser: model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) elif args.model_name == "Glm4MoeForCausalLM": model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) - elif args.model_name == "SmallthinkerForCausalLM": + elif args.model_name == "SmallThinkerForCausalLM": model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True) model_config._attn_implementation = "eager" else: @@ -153,7 +153,7 @@ class ArgumentParser: raise ValueError(f"Model {args.model_name} not supported. 
Please check your model directory or model name.") - if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallthinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM": + if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallThinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM": args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim args.architectures = model_config.architectures[0] else: diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py index 9a869a7..f7c7dc9 100644 --- a/ktransformers/server/backend/interfaces/balance_serve.py +++ b/ktransformers/server/backend/interfaces/balance_serve.py @@ -24,7 +24,7 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM -from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM +from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig from ktransformers.models.configuration_smallthinker import SmallthinkerConfig @@ -138,6 +138,7 @@ class Engine: elif args.model_name == "SmallThinkerForCausalLM": config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True) config._attn_implementation = "eager" + config.moe_intermediate_size = config.moe_ffn_hidden_size else: try: config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) @@ -164,7 +165,7 @@ class Engine: self.model = KQwen3MoeForCausalLM(config, self.cache) elif config.architectures[0] == "SmallThinkerForCausalLM": self.cache = KGQACache(config, self.args.page_size) - self.model = KSmallthinkerForCausalLM(config, self.cache) + self.model = KSmallThinkerForCausalLM(config, self.cache) elif config.architectures[0] == "Glm4MoeForCausalLM": self.cache = KGQACache(config, self.args.page_size) self.model = KGlm4MoeForCausalLM(config, self.cache) @@ -462,8 +463,8 @@ class BalanceServeInterface(BackendInterfaceBase): profiler.create_and_start_timer("prefill") query_add = sched_ext.QueryAdd() - input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444, - 99052, 101052, 11314]], device='cuda') + # input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444, + # 99052, 101052, 11314]], device='cuda') query_add.query_token = input_ids[0].tolist() query_length = input_ids[0].shape[0] query_add.query_length = query_length diff --git a/ktransformers/server/balance_serve/inference/model_runner.py b/ktransformers/server/balance_serve/inference/model_runner.py index 4a46279..75fb169 100644 --- a/ktransformers/server/balance_serve/inference/model_runner.py +++ b/ktransformers/server/balance_serve/inference/model_runner.py @@ -29,7 +29,7 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa from 
ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM -from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM +from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM from ktransformers.server.balance_serve.inference.query_manager import QueryManager from ktransformers.server.balance_serve.settings import sched_ext @@ -55,7 +55,7 @@ def generate_cuda_graphs(chunk_size: int) -> list: class ModelRunner: """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile.""" - model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallthinkerForCausalLM | KGlm4MoeForCausalLM + model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM input: ForwardBatchInput | list[ForwardBatchInput] output: ForwardBatchOutput @@ -95,7 +95,7 @@ class ModelRunner: num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) - elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallthinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM): + elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM): self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf, num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads, head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads, @@ -126,7 +126,7 @@ class ModelRunner: num_tokens = self.features_buf[i][0].size(0) print("capturing cuda graph", batch_size, num_tokens) - if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallthinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM): + if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM): self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens) self.bsz_tensor_buf[0] = batch_size
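
For reference, the `gpu_memory_size` that `ktransformers/server/args.py` derives for the GQA-style models covered here (Qwen2/3-MoE, SmallThinker, GLM-4-MoE) is a plain KV-cache estimate: `cache_lens * 2 (K and V) * 2 bytes (bf16) * num_hidden_layers * num_key_value_heads * head_dim`. The sketch below only restates that arithmetic; the model dimensions in the example are hypothetical placeholders, not the actual SmallThinker-21B or GLM-4-MoE 110B config values.

```python
# Minimal sketch of the KV-cache sizing formula used in ktransformers/server/args.py.
# The model dimensions below are hypothetical placeholders, not real model configs.

def kv_cache_bytes(cache_lens: int, num_hidden_layers: int,
                   num_key_value_heads: int, head_dim: int) -> int:
    # cache_lens tokens * 2 tensors (K and V) * 2 bytes per bf16 element
    # * layers * KV heads * head dimension
    return cache_lens * 2 * 2 * num_hidden_layers * num_key_value_heads * head_dim

# Example with placeholder dimensions: 32768 cached tokens, 48 layers, 8 KV heads, head_dim 128
print(kv_cache_bytes(32768, 48, 8, 128) / 2**30, "GiB")  # -> 6.0 GiB
```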
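
As a usage sketch to go with the `curl` examples in `doc/en/SmallThinker_and_Glm4moe.md`, the snippet below sends the same chat request from Python. It assumes the `requests` package, the ports and model names used in that guide, and a standard OpenAI-style chat-completions response body; adjust these to your deployment.

```python
# Minimal sketch: call the balance_serve OpenAI-compatible endpoint from Python,
# mirroring the curl examples in the guide. Assumes `pip install requests` and
# the ports/model names from the guide; adjust to your deployment.
import requests


def chat(port: int, model: str, prompt: str) -> str:
    """Send one non-streaming chat request and return the reply text."""
    resp = requests.post(
        f"http://localhost:{port}/v1/chat/completions",
        headers={"accept": "application/json", "Content-Type": "application/json"},
        json={
            "messages": [{"role": "user", "content": prompt}],
            "model": model,
            "temperature": 0.3,
            "top_p": 1.0,
            "stream": False,  # the curl examples use "stream": true (SSE)
        },
        timeout=600,
    )
    resp.raise_for_status()
    # Assumes the standard OpenAI chat-completions response schema.
    return resp.json()["choices"][0]["message"]["content"]


if __name__ == "__main__":
    print(chat(10021, "SmallThinker-21B", "hello"))  # SmallThinker-21B server
    print(chat(10110, "GLM-4-MoE-110B", "hello"))    # GLM-4-MoE 110B server
```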