From d03d92ba53e6b23c227b667c3e81ab4614e31f3f Mon Sep 17 00:00:00 2001
From: djw
Date: Fri, 25 Jul 2025 17:22:20 +0000
Subject: [PATCH] support glm4moe
---
.gitignore | 6 +-
README.md | 11 +-
csrc/balance_serve/CMakeLists.txt | 8 +-
csrc/balance_serve/sched/model_config.h | 6 +-
csrc/ktransformers_ext/ext_bindings.cpp | 8 +-
csrc/ktransformers_ext/operators/amx/moe.hpp | 50 +-
.../operators/llamafile/moe.cpp | 31 +-
.../operators/llamafile/moe.h | 5 +-
doc/en/SmallThinker_and_Glm4moe.md | 76 ++
ktransformers/configs/config.yaml | 14 +-
ktransformers/ktransformers | 1 +
.../models/configuration_glm4_moe.py | 242 +++++++
.../models/custom_modeling_glm4_moe.py | 124 ++++
ktransformers/models/modeling_glm4_moe.py | 649 ++++++++++++++++++
ktransformers/operators/RoPE.py | 92 ++-
.../operators/balance_serve_attention.py | 191 +++++-
ktransformers/operators/experts.py | 397 +++++++++++
ktransformers/operators/gate.py | 3 +-
ktransformers/operators/layernorm.py | 90 +++
ktransformers/operators/linear.py | 9 +-
ktransformers/operators/mlp.py | 33 +
.../optimize_rules/Glm4Moe-serve.yaml | 90 +++
ktransformers/server/args.py | 16 +-
.../backend/interfaces/balance_serve.py | 41 +-
.../balance_serve/inference/model_runner.py | 8 +-
.../server/balance_serve/sched_rpc.py | 6 +-
.../server/balance_serve/settings.py | 106 +++
ktransformers/tests/test_speed.py | 4 +-
ktransformers/util/custom_loader.py | 18 +-
pyproject.toml | 2 +-
requirements-local_chat.txt | 2 +-
31 files changed, 2265 insertions(+), 74 deletions(-)
create mode 100644 doc/en/SmallThinker_and_Glm4moe.md
create mode 120000 ktransformers/ktransformers
create mode 100644 ktransformers/models/configuration_glm4_moe.py
create mode 100644 ktransformers/models/custom_modeling_glm4_moe.py
create mode 100644 ktransformers/models/modeling_glm4_moe.py
create mode 100644 ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
diff --git a/.gitignore b/.gitignore
index 38bb53c..7cb23b1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,4 +26,8 @@ ktransformers/tests/chat_txt.txt
mmlu_result*
ktransformers/ktransformers_ext/cuda_musa/
test_prompt.txt
-csrc/demo
\ No newline at end of file
+csrc/demo
+build*
+CMakeFiles/
+kvc2/
+sched/
\ No newline at end of file
diff --git a/README.md b/README.md
index 4f630f3..117a5e0 100644
--- a/README.md
+++ b/README.md
@@ -23,19 +23,14 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+* **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md))
* **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md))
-
* **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse.
-
* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).
-
* **Apr 29, 2025**: Support AMX-Int8、 AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))
https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2
-
-
-
* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)).
* **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)).
@@ -65,7 +60,7 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM([Tutorial](./doc/en/DeepseekR1_V3_tutorial.md)).
-
+
- Prefill Speed (tokens/s):
- KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only)
- Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**.
@@ -131,7 +126,6 @@ we have already supported vendors:
- Kunpeng
- AMD
-
### 📥 Installation
To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).
@@ -201,3 +195,4 @@ If you have any questions, feel free to open an issue. Alternatively, you can jo
🙋 FAQ
Some common questions are answered in the [FAQ](doc/en/FAQ.md).
+
diff --git a/csrc/balance_serve/CMakeLists.txt b/csrc/balance_serve/CMakeLists.txt
index 4a78161..08fc430 100644
--- a/csrc/balance_serve/CMakeLists.txt
+++ b/csrc/balance_serve/CMakeLists.txt
@@ -10,10 +10,10 @@ message(STATUS "Using compiler: ${CMAKE_CXX_COMPILER}")
project(balance_serve VERSION 0.1.0)
set(CMAKE_CXX_STANDARD 20)
-# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
-# set(CMAKE_BUILD_TYPE "Debug")
-set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
-set(CMAKE_BUILD_TYPE "Release")
+set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
+set(CMAKE_BUILD_TYPE "Debug")
+# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
+# set(CMAKE_BUILD_TYPE "Release")
if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
diff --git a/csrc/balance_serve/sched/model_config.h b/csrc/balance_serve/sched/model_config.h
index e7512c4..78fc8dc 100644
--- a/csrc/balance_serve/sched/model_config.h
+++ b/csrc/balance_serve/sched/model_config.h
@@ -15,16 +15,14 @@ using ModelName = std::string;
class ModelConfig {
public:
DimSize hidden_size;
- DimSize intermediate_size;
size_t max_position_embeddings;
- std::string model_type;
size_t num_attention_heads;
size_t num_hidden_layers;
size_t num_key_value_heads;
size_t vocab_size;
- NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
- max_position_embeddings, model_type,
+ NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size,
+ max_position_embeddings,
num_attention_heads, num_hidden_layers,
num_key_value_heads, vocab_size);
diff --git a/csrc/ktransformers_ext/ext_bindings.cpp b/csrc/ktransformers_ext/ext_bindings.cpp
index f0aeaa5..a6a717b 100644
--- a/csrc/ktransformers_ext/ext_bindings.cpp
+++ b/csrc/ktransformers_ext/ext_bindings.cpp
@@ -683,12 +683,12 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
py::class_(moe_module, "MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
int intermediate_size, int stride, int group_min_len,
- int group_max_len, intptr_t gate_proj,
+ int group_max_len, bool use_silu, intptr_t gate_proj,
intptr_t up_proj, intptr_t down_proj, int gate_type,
int up_type, int down_type, int hidden_type) {
return MOEConfig(expert_num, routed_expert_num, hidden_size,
intermediate_size, stride, group_min_len,
- group_max_len, (void *)gate_proj, (void *)up_proj,
+ group_max_len, use_silu, (void *)gate_proj, (void *)up_proj,
(void *)down_proj, (ggml_type)gate_type,
(ggml_type)up_type, (ggml_type)down_type,
(ggml_type)hidden_type);
@@ -703,11 +703,11 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
py::class_(moe_module, "AMX_MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
int intermediate_size,
- int max_len, intptr_t gate_proj,
+ int max_len, bool use_silu, intptr_t gate_proj,
intptr_t up_proj, intptr_t down_proj) {
return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size,
intermediate_size,
- max_len, (void *)gate_proj,
+ max_len, use_silu, (void *)gate_proj,
(void *)up_proj, (void *)down_proj);
}));
diff --git a/csrc/ktransformers_ext/operators/amx/moe.hpp b/csrc/ktransformers_ext/operators/amx/moe.hpp
index 81df642..662189f 100644
--- a/csrc/ktransformers_ext/operators/amx/moe.hpp
+++ b/csrc/ktransformers_ext/operators/amx/moe.hpp
@@ -69,22 +69,29 @@ static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
return _mm512_mul_ps(act_val, up_val);
}
+static inline __m512 relu_act_fn(__m512 gate_val, __m512 up_val) {
+ __m512 zero_vec = _mm512_setzero_ps();
+ __m512 act_val = _mm512_max_ps(zero_vec, gate_val);
+ return _mm512_mul_ps(act_val, up_val);
+}
+
struct AMX_MOEConfig {
int expert_num;
int routed_expert_num;
int hidden_size;
int intermediate_size;
int max_len;
+ bool use_silu;
void *gate_proj;
void *up_proj;
void *down_proj;
AMX_MOEConfig() {}
- AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,
+ AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len, bool use_silu,
void *gate_proj, void *up_proj, void *down_proj)
: expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),
- intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj),
+ intermediate_size(intermediate_size), max_len(max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj),
down_proj(down_proj) {}
};
@@ -336,18 +343,35 @@ public:
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
- for (int i = 0; i < m_local_num_[expert_idx]; i++) {
- ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
- ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
- for (int j = n_start; j < n_end; j += 32) {
- __m512 gate_val0, gate_val1, up_val0, up_val1;
- avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
- avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
- __m512 result0 = act_fn(gate_val0, up_val0);
- __m512 result1 = act_fn(gate_val1, up_val1);
- avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
- }
+ if (config_.use_silu) {
+ for (int i = 0; i < m_local_num_[expert_idx]; i++) {
+ ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
+ ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
+ for (int j = n_start; j < n_end; j += 32) {
+ __m512 gate_val0, gate_val1, up_val0, up_val1;
+ avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
+ avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
+ __m512 result0 = act_fn(gate_val0, up_val0);
+ __m512 result1 = act_fn(gate_val1, up_val1);
+ avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
+ }
+ }
}
+ else {
+ for (int i = 0; i < m_local_num_[expert_idx]; i++) {
+ ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
+ ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
+ for (int j = n_start; j < n_end; j += 32) {
+ __m512 gate_val0, gate_val1, up_val0, up_val1;
+ avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
+ avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
+ __m512 result0 = relu_act_fn(gate_val0, up_val0);
+ __m512 result1 = relu_act_fn(gate_val1, up_val1);
+ avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
+ }
+ }
+ }
+
},
nullptr);
backend->do_work_stealing_job(
diff --git a/csrc/ktransformers_ext/operators/llamafile/moe.cpp b/csrc/ktransformers_ext/operators/llamafile/moe.cpp
index cd42691..86e55a2 100644
--- a/csrc/ktransformers_ext/operators/llamafile/moe.cpp
+++ b/csrc/ktransformers_ext/operators/llamafile/moe.cpp
@@ -10,6 +10,7 @@
#include "moe.h"
#include
#include
+#include
#ifdef USE_NUMA
#include
@@ -134,6 +135,14 @@ static float act_fn(float x) {
return x / (1.0f + expf(-x));
}
+static float act_fn_relu(float x) {
+ if(x > 0.0){
+ return x;
+ } else {
+ return 0.0;
+ }
+}
+
void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
const void* gate_input_ptr;
const void* up_input_ptr;
@@ -182,8 +191,16 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
- for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
- s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
+ if(config_.use_silu){
+ // use silu as act fn
+ for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
+ s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
+ }
+ } else {
+ // use relu as act fn
+ for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
+ s_intermediate_fp32_[expert_idx][i] = act_fn_relu(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
+ }
}
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride;
@@ -304,8 +321,14 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
- for (int j = ith * stride; j < (ith + 1) * stride; j++) {
- m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
+ if(config_.use_silu){
+ for (int j = ith * stride; j < (ith + 1) * stride; j++) {
+ m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
+ }
+ } else {
+ for (int j = ith * stride; j < (ith + 1) * stride; j++) {
+ m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn_relu(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
+ }
}
float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
diff --git a/csrc/ktransformers_ext/operators/llamafile/moe.h b/csrc/ktransformers_ext/operators/llamafile/moe.h
index 28d7ad3..b568b51 100644
--- a/csrc/ktransformers_ext/operators/llamafile/moe.h
+++ b/csrc/ktransformers_ext/operators/llamafile/moe.h
@@ -32,6 +32,7 @@ struct MOEConfig {
int stride;
int group_min_len;
int group_max_len;
+ bool use_silu;
void* gate_proj;
void* up_proj;
void* down_proj;
@@ -42,8 +43,8 @@ struct MOEConfig {
MOEConfig() {}
- MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
- : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
+ MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, bool use_silu, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
+ : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
};
class MOE {
diff --git a/doc/en/SmallThinker_and_Glm4moe.md b/doc/en/SmallThinker_and_Glm4moe.md
new file mode 100644
index 0000000..611f4ed
--- /dev/null
+++ b/doc/en/SmallThinker_and_Glm4moe.md
@@ -0,0 +1,76 @@
+# GLM-4-MoE Support for KTransformers
+
+## Introduction
+
+### Overview
+We are excited to announce that **KTransformers now supports GLM-4-MoE**.
+
+- **GLM-4-MoE 110B (bf16)**: ~11 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~440 GB DRAM.
+- **GLM-4-MoE 110B (AMX INT8)**: prefill ~309 TPS / decode ~16 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~220 GB DRAM.
+
+### Model & Resource Links
+- **GLM-4-MoE 110B**
+ - *(to be announced)*
+
+## Installation Guide
+
+### 1. Resource Requirements
+
+| Model | Precision | Experts | DRAM Needed | GPU Memory Needed\* | TPS (approx.) |
+| ------------------------- | --------- | ------- | ----------- | ------------------- | ------------------------------ |
+| GLM-4-MoE 110B | bf16 | 128 | \~440 GB | 14 GB | \~11 TPS |
+| GLM-4-MoE 110B (AMX INT8) | int8 | 128 | \~220 GB | 14 GB | prefill \~309 TPS / decode \~16 TPS |
+
+\* Exact GPU memory depends on sequence length, batch size, and kernels used.
+
+### 2. Prepare Models
+
+```bash
+# Example: download original safetensors (adjust to your paths/repos)
+# (Fill in actual repos/filenames yourself)
+
+# GLM-4-MoE 110B
+huggingface-cli download --resume-download placeholder-org/Model-TBA \
+ --local-dir ./Model-TBA
+````
+
+### 3. Install KTransformers
+
+Follow the official Installation Guide.
+
+```bash
+pip install ktransformers # or from source if you need bleeding-edge features
+```
+
+### 4. Run GLM-4-MoE 110B Inference Server
+
+```bash
+python ktransformers/server/main.py \
+ --port 10110 \
+ --model_name Glm4MoeForCausalLM \
+ --model_path /abs/path/to/GLM-4-MoE-110B-bf16 \
+ --optimize_config_path ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml \
+ --max_new_tokens 1024 \
+ --cache_lens 32768 \
+ --chunk_size 256 \
+ --max_batch_size 4 \
+ --backend_type balance_serve
+```
+
+### 5. Access Server
+
+```bash
+curl -X POST http://localhost:10110/v1/chat/completions \
+ -H "accept: application/json" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "messages": [
+ {"role": "user", "content": "hello"}
+ ],
+ "model": "GLM-4-MoE-110B",
+ "temperature": 0.3,
+ "top_p": 1.0,
+ "stream": true
+ }'
+```
+
diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml
index 3bd60f9..4050c8c 100644
--- a/ktransformers/configs/config.yaml
+++ b/ktransformers/configs/config.yaml
@@ -21,12 +21,12 @@ user:
model:
# type: transformers
- # type: balance_serve
- type: ktransformers
+ type: balance_serve
+ # type: ktransformers
- name: DeepSeek-Coder-V2-Instruct
- path: deepseek-ai/DeepSeek-V2-Lite-Chat
- gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF
+ name: SmallThinkerForCausalLM
+ path: /mnt/data/models/Smallthinker-21B
+ gguf_path: /mnt/data/models/Smallthinker-21B
device: cuda:0
cache_lens: 16384
@@ -67,7 +67,7 @@ attn:
page_size: 256
chunk_size: 256
kvc2:
- gpu_only: false
+ gpu_only: true
utilization_percentage: 1.0
cpu_memory_size_GB: 500
- disk_path: /mnt/data/kvc
\ No newline at end of file
+ disk_path: /home/wjh/kvc
\ No newline at end of file
diff --git a/ktransformers/ktransformers b/ktransformers/ktransformers
new file mode 120000
index 0000000..598751a
--- /dev/null
+++ b/ktransformers/ktransformers
@@ -0,0 +1 @@
+/home/djw/py311_717/ktransformers/ktransformers
\ No newline at end of file
diff --git a/ktransformers/models/configuration_glm4_moe.py b/ktransformers/models/configuration_glm4_moe.py
new file mode 100644
index 0000000..9906ab0
--- /dev/null
+++ b/ktransformers/models/configuration_glm4_moe.py
@@ -0,0 +1,242 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm4_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+
+
+class Glm4MoeConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4MoeModel`]. It is used to instantiate a
+ Glm4Moe model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of [THUDM/GLM-4-100B-A10B](https://huggingface.co/THUDM/GLM-4-100B-A10B).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151552):
+ Vocabulary size of the Glm4Moe model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Glm4MoeModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 10944):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 46):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 96):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+ The factor of the partial rotary position.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
+
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 131072):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
+ Intermediate size of the routed expert.
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
+ number of experts per token.
+ n_shared_experts (`int`, *optional*, defaults to 1):
+ Number of shared experts.
+ n_routed_experts (`int`, *optional*, defaults to 128):
+ Number of routed experts.
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor or routed experts.
+ n_group (`int`, *optional*, defaults to 1):
+ Number of groups for routed experts.
+ topk_group (`int`, *optional*, defaults to 1):
+ Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
+ first_k_dense_replace (`int`, *optional*, defaults to 1):
+ Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
+ \--k dense layers--/
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the topk probabilities.
+ use_qk_norm (`bool`, *optional*, defaults to `False`):
+ Whether to use query-key normalization in the attention
+ ```python
+ >>> from transformers import Glm4MoeModel, Glm4MoeConfig
+
+ >>> # Initializing a Glm4Moe style configuration
+ >>> configuration = Glm4MoeConfig()
+
+ >>> # Initializing a model from the GLM-4-MOE-100B-A10B style configuration
+ >>> model = Glm4MoeModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm4_moe"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ # Default tensor parallel plan for base model `Glm4Moe`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.experts.*.gate_proj": "colwise",
+ "layers.*.mlp.experts.*.up_proj": "colwise",
+ "layers.*.mlp.experts.*.down_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=151552,
+ hidden_size=4096,
+ intermediate_size=10944,
+ num_hidden_layers=46,
+ num_attention_heads=96,
+ partial_rotary_factor=0.5,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=131072,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ moe_intermediate_size=1408,
+ num_experts_per_tok=8,
+ n_shared_experts=1,
+ n_routed_experts=128,
+ routed_scaling_factor=1.0,
+ n_group=1,
+ topk_group=1,
+ first_k_dense_replace=1,
+ norm_topk_prob=True,
+ use_qk_norm=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.partial_rotary_factor = partial_rotary_factor
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ # MoE arguments
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.routed_scaling_factor = routed_scaling_factor
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+ self.use_qk_norm = use_qk_norm
+
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["Glm4MoeConfig"]
\ No newline at end of file
diff --git a/ktransformers/models/custom_modeling_glm4_moe.py b/ktransformers/models/custom_modeling_glm4_moe.py
new file mode 100644
index 0000000..f33c177
--- /dev/null
+++ b/ktransformers/models/custom_modeling_glm4_moe.py
@@ -0,0 +1,124 @@
+"""
+Date: 2024-11-06 10:05:11
+LastEditors: djw
+LastEditTime: 2024-11-13 07:50:51
+"""
+
+import math
+from dataclasses import dataclass
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import math
+from typing import List, Optional, Tuple, Union
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
+from ktransformers.models.custom_cache import KGQACache
+from ktransformers.models.modeling_glm4_moe import Glm4MoeModel, Glm4MoePreTrainedModel
+from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
+from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn
+
+torch.set_grad_enabled(False)
+torch.set_default_dtype(torch.bfloat16)
+import flashinfer
+
+class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel):
+
+ cache: KGQACache
+ use_cuda_graph = False
+ def __init__(
+ self,
+ config: Glm4MoeConfig,
+ cache,
+ ):
+
+ super().__init__(config)
+ self.model = Glm4MoeModel(config)
+ self.config = config
+ self.cache = cache
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.attn = [None] * 100
+
+ def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):
+ self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)
+
+
+ def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
+ features = []
+ for i in range(batch.batch_size):
+ tokens = batch.minibatch.tokens.contiguous()
+ feature = (
+ self.model.embed_tokens(tokens.to(torch.device('cpu')))
+ .to(torch.bfloat16)
+ .to(device=device)
+ )
+ features.append(feature)
+
+ return features
+
+
+ def forward(
+ self,
+ batch: ForwardBatchInput | None = None,
+ features: List[torch.Tensor] | None = None,
+ bsz_tensors: torch.Tensor | None = None,
+ num_tokens_tensors: torch.Tensor | None = None,
+ page_idx: torch.Tensor | None = None,
+ page_offset: torch.Tensor | None = None,
+ cuda_graph_idx: int | None = 0
+ ) -> ForwardBatchOutput:
+ current_stream = torch.cuda.current_stream()
+
+ forward_batch_output = ForwardBatchOutput()
+
+
+ hidden_states = features[0]
+ self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])
+
+ freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0))
+
+
+ with torch.cuda.stream(current_stream):
+ residual = torch.zeros_like(hidden_states)
+ for i, decode_layer in enumerate(self.model.layers):
+
+ hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
+ hidden_states = decode_layer.self_attn(hidden_states, self.cache,
+ freqs_cis,
+ wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors,
+ position_ids=batch.minibatch.position_ids
+ )
+
+ hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
+ if i < self.model.config.first_k_dense_replace:
+ hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)
+ else:
+ hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors, cuda_graph_idx)
+ # hidden_states = hidden_states.squeeze(0)
+
+ forward_batch_output = ForwardBatchOutput()
+ with torch.cuda.stream(current_stream):
+ local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
+ forward_batch_output.logits.append(local_logit)
+
+ return forward_batch_output
+
+
+
+ def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
+ num_q_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ page_size: int,
+ causal: bool,
+ q_data_type: torch.dtype,
+ kv_data_type: torch.dtype,
+ cuda_graph_idx: int = 0
+ ):
+ minibatch = batch.minibatch
+ self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
+ minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)
+
\ No newline at end of file
diff --git a/ktransformers/models/modeling_glm4_moe.py b/ktransformers/models/modeling_glm4_moe.py
new file mode 100644
index 0000000..32727a8
--- /dev/null
+++ b/ktransformers/models/modeling_glm4_moe.py
@@ -0,0 +1,649 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm4_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+# from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+# from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+from transformers.utils import auto_docstring, can_return_tuple
+# from transformers.utils.generic import check_model_inputs
+from .configuration_glm4_moe import Glm4MoeConfig
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ # **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+ return q_embed, k_embed
+
+
+class Glm4MoeAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Glm4MoeConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.use_qk_norm = config.use_qk_norm
+ if self.use_qk_norm:
+ self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+
+ if self.use_qk_norm: # main diff from Llama
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Glm4MoeMLP(nn.Module):
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Glm4MoeTopkRouter(nn.Module):
+ def __init__(self, config: Glm4MoeConfig):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.n_routed_experts = config.n_routed_experts
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+ self.norm_topk_prob = config.norm_topk_prob
+
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+ self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32))
+
+ @torch.no_grad()
+ def get_topk_indices(self, scores):
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+ group_scores = (
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .topk(2, dim=-1)[0]
+ .sum(dim=-1)
+ )
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+ group_mask = torch.zeros_like(group_scores)
+ group_mask.scatter_(1, group_idx, 1)
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .reshape(-1, self.n_routed_experts)
+ )
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+ return topk_indices
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+ scores = router_logits.sigmoid()
+ topk_indices = self.get_topk_indices(scores)
+ topk_weights = scores.gather(1, topk_indices)
+ if self.norm_topk_prob:
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+ topk_weights /= denominator
+ topk_weights = topk_weights * self.routed_scaling_factor
+ return topk_indices, topk_weights
+
+
+# @use_kernel_forward_from_hub("RMSNorm")
+class Glm4MoeRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Glm4MoeRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Glm4MoeMoE(nn.Module):
+ """
+ A mixed expert module containing shared experts.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.experts = nn.ModuleList(
+ [
+ Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size)
+ for _ in range(config.n_routed_experts)
+ ]
+ )
+ self.gate = Glm4MoeTopkRouter(config)
+ self.shared_experts = Glm4MoeMLP(
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+ )
+
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+ r"""
+ CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
+ to not have to do a loop here (deepseek has 256 experts soooo yeah).
+ """
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+ expert_mask = expert_mask.permute(2, 0, 1)
+
+ for expert_idx in range(len(self.experts)):
+ expert = self.experts[expert_idx]
+ mask = expert_mask[expert_idx]
+ token_indices, weight_indices = torch.where(mask)
+
+ if token_indices.numel() > 0:
+ expert_weights = topk_weights[token_indices, weight_indices]
+ expert_input = hidden_states[token_indices]
+ expert_output = expert(expert_input)
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+ # in original deepseek, the output of the experts are gathered once we leave this module
+ # thus the moe module is itelsf an IsolatedParallel module
+ # and all expert are "local" meaning we shard but we don't gather
+ return final_hidden_states.type(hidden_states.dtype)
+
+ def forward(self, hidden_states):
+ residuals = hidden_states
+ orig_shape = hidden_states.shape
+ topk_indices, topk_weights = self.gate(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+ hidden_states = hidden_states + self.shared_experts(residuals)
+ return hidden_states
+
+
+class Glm4MoeDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Glm4MoeConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx)
+
+ if layer_idx >= config.first_k_dense_replace:
+ self.mlp = Glm4MoeMoE(config)
+ else:
+ self.mlp = Glm4MoeMLP(config)
+
+ self.input_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ # **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class Glm4MoePreTrainedModel(PreTrainedModel):
+ config: Glm4MoeConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Glm4MoeDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_static_cache = False
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Glm4MoeDecoderLayer,
+ "attentions": Glm4MoeAttention,
+ }
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, Glm4MoeRMSNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Glm4MoeTopkRouter):
+ module.weight.data.normal_(mean=0.0, std=std)
+
+
+class Glm4MoeRotaryEmbedding(nn.Module):
+ def __init__(self, config: Glm4MoeConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class Glm4MoeModel(Glm4MoePreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"]
+
+ def __init__(self, config: Glm4MoeConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Glm4MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Glm4MoeRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ # **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Glm4MoeModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ # **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Glm4MoeForCausalLM
+
+ >>> model = Glm4MoeForCausalLM.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ # **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["Glm4MoePreTrainedModel", "Glm4MoeModel", "Glm4MoeForCausalLM"]
\ No newline at end of file
diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py
index 85d6556..968c7b9 100644
--- a/ktransformers/operators/RoPE.py
+++ b/ktransformers/operators/RoPE.py
@@ -26,6 +26,8 @@ from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig
+from ktransformers.models.modeling_smallthinker import SmallthinkerRotaryEmbedding
+from ktransformers.models.modeling_glm4_moe import Glm4MoeRotaryEmbedding
import torch
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
@@ -437,4 +439,92 @@ class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
def load(self):
self.orig_module.__init__(
self.orig_module.config
- )
\ No newline at end of file
+ )
+
+
+class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding):
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ # device: str = "cuda",
+ generate_device: str = "cuda",
+ prefill_device: str = "cuda",
+ **kwargs,
+ ):
+ BaseInjectedModule.__init__(
+ self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
+ )
+ self.orig_module.__init__(
+ config
+ )
+ self.generate_device = generate_device
+ self.prefill_device = prefill_device
+
+ def load(self):
+ self.orig_module.__init__(
+ self.orig_module.config,
+ device = self.generate_device,
+ )
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ # device: str = "cuda",
+ generate_device: str = "cuda",
+ prefill_device: str = "cuda",
+ **kwargs,
+ ):
+ BaseInjectedModule.__init__(
+ self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
+ )
+ self.orig_module.__init__(
+ config
+ )
+ self.generate_device = generate_device
+ self.prefill_device = prefill_device
+
+ def load(self):
+ self.orig_module.__init__(
+ self.orig_module.config,
+ device = self.generate_device,
+ )
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(position_ids, device=x.device)
+ # Core RoPE block
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ # print(inv_freq_expanded.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
\ No newline at end of file
diff --git a/ktransformers/operators/balance_serve_attention.py b/ktransformers/operators/balance_serve_attention.py
index 51695f3..d493329 100644
--- a/ktransformers/operators/balance_serve_attention.py
+++ b/ktransformers/operators/balance_serve_attention.py
@@ -9,6 +9,8 @@ from torch import nn
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
+from ktransformers.models.modeling_smallthinker import SmallthinkerAttention
+from ktransformers.models.modeling_glm4_moe import Glm4MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
@@ -454,4 +456,191 @@ class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):
attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)
final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
- return final_attention_output
\ No newline at end of file
+ return final_attention_output
+
+class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "cuda",
+ generate_device: str = "cuda",
+ chunck_size: int = 1000,
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.config,
+ orig_module.layer_idx)
+ self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
+
+ def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ kv_cache: KGQACache,
+ freqs_cis: torch.Tensor,
+ wrapper: flashInferAttn,
+ bsz_tensors: torch.Tensor,
+ position_ids: torch.Tensor = None,
+ ):
+
+ if self.use_qk_norm:
+ raise NotImplementedError("use_qk_norm is not implemented yet")
+
+ q_len, _ = hidden_states.size()
+ query_states = self.q_proj(hidden_states, bsz_tensors)
+ key_states = self.k_proj(hidden_states, bsz_tensors)
+ value_states = self.v_proj(hidden_states, bsz_tensors)
+
+ query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
+ key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
+ value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
+
+ # cos, sin = freqs_cis
+ """
+ print(query_states.shape)
+ print(key_states.shape)
+ print(cos.shape)
+ print(sin.shape)
+ """
+ if freqs_cis:
+ cos, sin = freqs_cis
+ query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)
+
+
+
+ query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
+ key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
+ value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
+
+ k_cache = kv_cache.get_k_cache(self.layer_idx)
+ v_cache = kv_cache.get_v_cache(self.layer_idx)
+
+
+ attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
+
+
+ attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
+
+ return attn_output
+
+
+
+
+class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "cuda",
+ generate_device: str = "cuda",
+ chunck_size: int = 1000,
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.config,
+ orig_module.layer_idx)
+ self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
+
+ def apply_rotary_pos_emb(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor],
+ unsqueeze_dim=2
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ # Keep half or full tensor for later concatenation
+ cos = freqs_cis[0]
+ sin = freqs_cis[1]
+ rotary_dim = cos.shape[-1]
+
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+ return q_embed, k_embed
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ kv_cache: KGQACache,
+ freqs_cis: torch.Tensor,
+ wrapper: flashInferAttn,
+ bsz_tensors: torch.Tensor,
+ position_ids: torch.Tensor = None,
+ ):
+
+ q_len, _ = hidden_states.size()
+ query_states = self.q_proj(hidden_states, bsz_tensors)
+ key_states = self.k_proj(hidden_states, bsz_tensors)
+ value_states = self.v_proj(hidden_states, bsz_tensors)
+
+
+ if self.use_qk_norm:
+ query_states = self.q_norm(query_states, bsz_tensors)
+ key_states = self.k_norm(key_states, bsz_tensors)
+
+
+ query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
+ key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+ value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+
+ # cos, sin = freqs_cis
+ """
+ print(query_states.shape)
+ print(key_states.shape)
+ print(cos.shape)
+ print(sin.shape)
+ """
+ if freqs_cis is not None:
+ query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
+
+
+
+ query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
+ key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+ value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+
+ k_cache = kv_cache.get_k_cache(self.layer_idx)
+ v_cache = kv_cache.get_v_cache(self.layer_idx)
+
+
+ attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
+
+
+ attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors)
+
+ return attn_output
\ No newline at end of file
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 7a40168..f4131d1 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -194,6 +194,7 @@ class KExpertsCPU(KExpertsBase):
64,
10,
1024,
+ self.config.hidden_act == 'silu',
gate_ptr,
up_ptr,
down_ptr,
@@ -214,6 +215,7 @@ class KExpertsCPU(KExpertsBase):
self.config.hidden_size,
self.config.moe_intermediate_size,
max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
+ self.config.hidden_act == 'silu',
gate_ptr,
up_ptr,
down_ptr,
@@ -232,6 +234,7 @@ class KExpertsCPU(KExpertsBase):
self.config.hidden_size,
self.config.moe_intermediate_size,
max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
+ self.config.hidden_act == 'silu',
gate_ptr,
up_ptr,
down_ptr,
@@ -729,6 +732,8 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
+from ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock
+from ktransformers.models.modeling_glm4_moe import Glm4MoeMoE
class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
@@ -1248,6 +1253,12 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+
+ if prefill_op == 'None':
+ prefill_op = None
+ if generate_op == 'None':
+ generate_op = None
+
if generate_op is not None:
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
else:
@@ -1307,6 +1318,152 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase):
+ def __init__(self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ # device: str = "cuda",
+ prefill_device:str = "cuda",
+ prefill_op: str | None = "KExpertsTorch",
+ generate_device: str = "cpu",
+ generate_op: str | None = "KExpertsCPU",
+ **kwargs):
+
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+ KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+ if generate_op is not None:
+ self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
+ else:
+ self.generate_experts = None
+ if prefill_op is not None:
+ self.prefill_experts = None
+ self.gpu_mlp_type = prefill_op
+ self.cpu_mlp_type = generate_op
+ self.mode = InferenceState.UNLOAD
+
+ def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
+ # TODO support w as input
+ if not mode: mode = InferenceState.GENERATE
+ if mode == InferenceState.GENERATE:
+ # self.prefill_experts.unload()
+ self.generate_experts.load(w, warmup=warmup)
+ self.device = self.generate_experts.device
+ self.mode = mode
+ elif mode == InferenceState.PREFILL:
+ self.generate_experts.unload()
+ self.prefill_experts.load(w, warmup=warmup)
+ self.device = self.prefill_experts.device
+ self.mode = mode
+ elif mode == InferenceState.UNLOAD:
+ self.unload()
+ self.mode = mode
+ self.device = self.generate_experts.device
+ else:
+ raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+ def unload(self):
+ if self.generate_experts is not None:
+ self.generate_experts.unload()
+ if self.prefill_experts is not None:
+ self.prefill_experts.unload()
+ self.device = self.generate_experts.device
+
+ def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
+ if self.mode == InferenceState.GENERATE:
+ assert self.generate_experts is not None, "generate_experts is None"
+ return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+ elif self.mode == InferenceState.PREFILL:
+ assert self.prefill_experts is not None, "prefill_experts is None"
+ return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+ else:
+ raise ValueError("load or set_inference_mode before forward")
+
+ def set_inference_mode(self, mode: InferenceState):
+ if mode == InferenceState.GENERATE:
+ self.load(mode=InferenceState.GENERATE, warmup=False)
+ elif mode == InferenceState.PREFILL:
+ self.load(mode=InferenceState.PREFILL, warmup=False)
+ elif mode == InferenceState.UNLOAD:
+ self.unload()
+ else:
+ raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+class KGlm4Experts(BaseInjectedModule, KExpertsBase):
+ def __init__(self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ # device: str = "cuda",
+ prefill_device:str = "cuda",
+ prefill_op: str | None = "KExpertsTorch",
+ generate_device: str = "cpu",
+ generate_op: str | None = "KExpertsCPU",
+ **kwargs):
+
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+ KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+ if generate_op is not None:
+ self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
+ else:
+ self.generate_experts = None
+ if prefill_op is not None:
+ self.prefill_experts = None
+ self.gpu_mlp_type = prefill_op
+ self.cpu_mlp_type = generate_op
+ self.mode = InferenceState.UNLOAD
+
+ def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
+ # TODO support w as input
+ if not mode: mode = InferenceState.GENERATE
+ if mode == InferenceState.GENERATE:
+ # self.prefill_experts.unload()
+ self.generate_experts.load(w, warmup=warmup)
+ self.device = self.generate_experts.device
+ self.mode = mode
+ elif mode == InferenceState.PREFILL:
+ self.generate_experts.unload()
+ self.prefill_experts.load(w, warmup=warmup)
+ self.device = self.prefill_experts.device
+ self.mode = mode
+ elif mode == InferenceState.UNLOAD:
+ self.unload()
+ self.mode = mode
+ self.device = self.generate_experts.device
+ else:
+ raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+ def unload(self):
+ if self.generate_experts is not None:
+ self.generate_experts.unload()
+ if self.prefill_experts is not None:
+ self.prefill_experts.unload()
+ self.device = self.generate_experts.device
+
+ def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
+ if self.mode == InferenceState.GENERATE:
+ assert self.generate_experts is not None, "generate_experts is None"
+ return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+ elif self.mode == InferenceState.PREFILL:
+ assert self.prefill_experts is not None, "prefill_experts is None"
+ return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+ else:
+ raise ValueError("load or set_inference_mode before forward")
+
+ def set_inference_mode(self, mode: InferenceState):
+ if mode == InferenceState.GENERATE:
+ self.load(mode=InferenceState.GENERATE, warmup=False)
+ elif mode == InferenceState.PREFILL:
+ self.load(mode=InferenceState.PREFILL, warmup=False)
+ elif mode == InferenceState.UNLOAD:
+ self.unload()
+ else:
+ raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+
class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
@@ -1507,6 +1664,246 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
)
return outs
+ @torch.no_grad()
+ # TODO may bugs here
+ def moe_infer(self, x, topk_ids, topk_weight):
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
+ cnts.scatter_(1, topk_ids, 1)
+ tokens_per_expert = cnts.sum(dim=0)
+ idxs = topk_ids.view(-1).argsort()
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+ outputs = []
+ start_idx = 0
+ for i, num_tokens in enumerate(tokens_per_expert):
+ end_idx = start_idx + num_tokens
+ if num_tokens == 0:
+ continue
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+ expert_out = expert.forward(tokens_for_this_expert)
+ outputs.append(expert_out)
+ start_idx = end_idx
+
+ outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+
+ new_x = torch.empty_like(outs)
+ new_x[idxs] = outs
+ final_out = (
+ new_x.view(*topk_ids.shape, -1)
+ .type(topk_weight.dtype)
+ .mul_(topk_weight.unsqueeze(dim=-1))
+ .sum(dim=1)
+ .type(new_x.dtype)
+ )
+ return final_out
+
+
+class KSmallthinkerMoeBlock(BaseInjectedModule, SmallthinkerMoeBlock):
+ def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0):
+
+ orig_shape = hidden_states.shape
+ sequence_length = orig_shape[1]
+
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+ if bsz_tensor is None:
+ if self.enable_early_router:
+ router_logits = self.primary_router(router_input)
+ else:
+ router_logits = self.primary_router(hidden_states)
+ else:
+ if self.enable_early_router:
+ router_logits = self.primary_router(router_input, bsz_tensor)
+ else:
+ router_logits = self.primary_router(hidden_states, bsz_tensor)
+
+ router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1)
+
+
+ if router_logits.device.type == "xpu":
+ # TODO: support self.moe_primary_router_apply_softmax False case
+ from ipex_llm.transformers.models.common import moe_softmax_topk
+ selected_experts, routing_weights = moe_softmax_topk(
+ router_logits.half(), self.top_k, self.norm_topk_prob
+ )
+ else:
+ if self.moe_primary_router_apply_softmax:
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ else:
+ routing_weights = F.sigmoid(router_logits)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ # only for generate phase
+ if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+ self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
+ # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
+ # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
+
+ y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
+
+ # y += y_
+ y.resize_(*orig_shape)
+ return y
+
+ # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
+ # y_ = (
+ # F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
+ # )
+
+
+ if isinstance(self.experts, KExpertsBase):
+ y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
+ elif hidden_states.size(0) > 10:
+ # TODO may bugs here
+ y = (
+ self.moe_infer(hidden_states, selected_experts, routing_weights)
+ .view(*orig_shape)
+ .to(device=hidden_states.device)
+ )
+ else:
+ # TODO may bugs here
+ y = (
+ self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
+ .view(*orig_shape)
+ .to(device=hidden_states.device)
+ )
+ # y += y_
+ return y
+
+ @torch.no_grad()
+ def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
+ outs = torch.empty_like(x)
+ outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
+ return outs
+
+ @torch.no_grad()
+ # TODO may bugs here
+ def moe_infer_simple(
+ self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ x: [num_tokens, hidden_size]
+ topk_ids, topk_weight: [num_tokens, num_selected_experts]
+ """
+ outs = torch.zeros_like(x)
+ for token_idx in range(topk_ids.size(0)):
+ for expert_idx in range(topk_ids.size(1)):
+ expert = self.experts[topk_ids[token_idx, expert_idx]]
+ outs[token_idx] += (
+ expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
+ )
+ return outs
+
+ @torch.no_grad()
+ # TODO may bugs here
+ def moe_infer(self, x, topk_ids, topk_weight):
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
+ cnts.scatter_(1, topk_ids, 1)
+ tokens_per_expert = cnts.sum(dim=0)
+ idxs = topk_ids.view(-1).argsort()
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+ outputs = []
+ start_idx = 0
+ for i, num_tokens in enumerate(tokens_per_expert):
+ end_idx = start_idx + num_tokens
+ if num_tokens == 0:
+ continue
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+ expert_out = expert.forward(tokens_for_this_expert)
+ outputs.append(expert_out)
+ start_idx = end_idx
+
+ outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+
+ new_x = torch.empty_like(outs)
+ new_x[idxs] = outs
+ final_out = (
+ new_x.view(*topk_ids.shape, -1)
+ .type(topk_weight.dtype)
+ .mul_(topk_weight.unsqueeze(dim=-1))
+ .sum(dim=1)
+ .type(new_x.dtype)
+ )
+ return final_out
+
+
+class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
+ def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
+
+ orig_shape = hidden_states.shape
+ sequence_length = orig_shape[1]
+
+ topk_idx, topk_weight = self.gate(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+ # only for generate phase
+ if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+ self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
+ y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
+ # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
+
+ y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
+
+ y += y_
+ y.resize_(*orig_shape)
+ return y
+
+ y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
+ # y_ = (
+ # F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
+ # )
+
+
+ if isinstance(self.experts, KExpertsBase):
+ y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
+ elif hidden_states.size(0) > 10:
+ # TODO may bugs here
+ y = (
+ self.moe_infer(hidden_states, topk_idx, topk_weight)
+ .view(*orig_shape)
+ .to(device=hidden_states.device)
+ )
+ else:
+ # TODO may bugs here
+ y = (
+ self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
+ .view(*orig_shape)
+ .to(device=hidden_states.device)
+ )
+ y += y_
+ return y
+
+ @torch.no_grad()
+ def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
+ outs = torch.empty_like(x)
+ outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
+ return outs
+
+ @torch.no_grad()
+ # TODO may bugs here
+ def moe_infer_simple(
+ self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ x: [num_tokens, hidden_size]
+ topk_ids, topk_weight: [num_tokens, num_selected_experts]
+ """
+ outs = torch.zeros_like(x)
+ for token_idx in range(topk_ids.size(0)):
+ for expert_idx in range(topk_ids.size(1)):
+ expert = self.experts[topk_ids[token_idx, expert_idx]]
+ outs[token_idx] += (
+ expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
+ )
+ return outs
+
@torch.no_grad()
# TODO may bugs here
def moe_infer(self, x, topk_ids, topk_weight):
diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py
index f5f96c1..b9ccc01 100644
--- a/ktransformers/operators/gate.py
+++ b/ktransformers/operators/gate.py
@@ -212,4 +212,5 @@ class KMoEGateIPEXLLM(KMoEGate):
topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias,
self.n_group, self.topk_group, self.top_k,
self.norm_topk_prob, self.routed_scaling_factor)
- return topk_idx, topk_weight.to(x.dtype)
\ No newline at end of file
+ return topk_idx, topk_weight.to(x.dtype)
+
diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py
index 796592c..24bdc81 100644
--- a/ktransformers/operators/layernorm.py
+++ b/ktransformers/operators/layernorm.py
@@ -28,6 +28,8 @@ import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
+from ktransformers.models.modeling_smallthinker import SmallthinkerRMSNorm
+from ktransformers.models.modeling_glm4_moe import Glm4MoeRMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
if not torch.xpu.is_available():
@@ -164,6 +166,94 @@ class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
+
+class KSmallthinkerRMSNorm(SmallthinkerRMSNorm, BaseInjectedModule):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "cuda",
+ generate_device: str = "cuda",
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.hidden_size,
+ orig_module.variance_epsilon)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ batch_size_tensor: torch.Tensor = None,
+ residual: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ #return self.forward_native(x, residual)
+ bsz, hidden_size = x.shape
+ x = x.view(-1, self.orig_module.hidden_size)
+ if batch_size_tensor is None:
+ return self.forward_native(x)
+ if residual is not None:
+ fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
+ #residual = x + residual
+ #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
+ return x, residual
+ # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
+ out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)
+ out = out.view(bsz, hidden_size)
+ return out
+
+ def forward_native(
+ self, hidden_states
+ ):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class KGlm4MoeRMSNorm(Glm4MoeRMSNorm, BaseInjectedModule):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "cuda",
+ generate_device: str = "cuda",
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.hidden_size,
+ orig_module.variance_epsilon)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ batch_size_tensor: torch.Tensor = None,
+ residual: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ #return self.forward_native(x, residual)
+ bsz, hidden_size = x.shape
+ x = x.view(-1, self.orig_module.hidden_size)
+ if batch_size_tensor is None:
+ return self.forward_native(x)
+ if residual is not None:
+ fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
+ #residual = x + residual
+ #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
+ return x, residual
+ # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
+ out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)
+ out = out.view(bsz, hidden_size)
+ return out
+
+ def forward_native(
+ self, hidden_states
+ ):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
def __init__(self,
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 654c9f9..e232f83 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -88,10 +88,17 @@ class KLinearBase(ABC):
if isinstance(self.gguf_loader, SafeTensorLoader):
# using safetensor_loader
tensor = self.gguf_loader.load_tensor(key+'.weight')
+ try:
+ bias = self.gguf_loader.load_tensor(key+'.bias')
+ except:
+ bias = None
if self.gguf_loader.has_tensor(key+'.weight_scale_inv'):
weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv')
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
- return nn.Parameter(tensor)
+ if bias is not None:
+ return nn.Parameter(tensor), nn.Parameter(bias)
+ else:
+ return nn.Parameter(tensor)
elif self.gguf_loader.has_tensor(key + ".weight") or "kv_b_proj" in key:
if key + ".bias" in self.gguf_loader.tensor_file_map:
diff --git a/ktransformers/operators/mlp.py b/ktransformers/operators/mlp.py
index 77d7d05..6d3e812 100644
--- a/ktransformers/operators/mlp.py
+++ b/ktransformers/operators/mlp.py
@@ -5,6 +5,8 @@ from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP
+from ktransformers.models.modeling_smallthinker import SmallthinkerDenseMlpBlock
+from ktransformers.models.modeling_glm4_moe import Glm4MoeMLP
class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
def __init__(self,
key: str,
@@ -32,6 +34,37 @@ class KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.intermediate_size)
+ def forward(self, x, bsz_tensor):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
+ return down_proj
+
+
+class KSmallthinkerDenseMlpBlock(SmallthinkerDenseMlpBlock, BaseInjectedModule):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "cuda",
+ generate_device: str = "cuda",
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.config)
+ def forward(self, x, bsz_tensor):
+ down_proj = self.down(nn.functional.relu(self.gate(x, bsz_tensor)) * self.up(x, bsz_tensor), bsz_tensor)
+ return down_proj
+
+class KGlm4MoeMLP(Glm4MoeMLP, BaseInjectedModule):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "cuda",
+ generate_device: str = "cuda",
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size)
def forward(self, x, bsz_tensor):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
return down_proj
\ No newline at end of file
diff --git a/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml b/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
new file mode 100644
index 0000000..56345df
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
@@ -0,0 +1,90 @@
+- match:
+ class: ktransformers.models.modeling_glm4_moe.Glm4MoeRotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.KGlm4MoeRotaryEmbedding
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+
+- match:
+ name: "^lm_head$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+ generate_op: "VLinearMarlin"
+ prefill_op: "KLinearTorch"
+
+# - match:
+# name: "^model\\.layers\\..*$" # regular expression
+# class: torch.nn.Linear # only match modules matching name and class simultaneously
+# replace:
+# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+# kwargs:
+# generate_device: "cuda"
+# prefill_device: "cuda"
+# generate_op: "VLinearMarlin"
+# prefill_op: "KLinearTorch"
+- match:
+ name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+ generate_op: "KLinearMarlin"
+ prefill_op: "KLinearTorch"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: ktransformers.models.modeling_glm4_moe.Glm4MoeMoE
+ replace:
+ class: ktransformers.operators.experts.KGlm4MoeMoE
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KGlm4Experts # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "cuda"
+ prefill_op: None
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "cuda"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.balance_serve_attention.KGlm4MoeAttention # optimized MLA implementation
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+
+- match:
+ class: ktransformers.models.modeling_glm4_moe.Glm4MoeRMSNorm
+ replace:
+ class: ktransformers.operators.layernorm.KGlm4MoeRMSNorm
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+
+- match:
+ class: ktransformers.models.modeling_glm4_moe.Glm4MoeMLP
+ replace:
+ class: ktransformers.operators.mlp.KGlm4MoeMLP
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
\ No newline at end of file
diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py
index 748bd47..ecf1f84 100644
--- a/ktransformers/server/args.py
+++ b/ktransformers/server/args.py
@@ -2,6 +2,9 @@ import argparse
from ktransformers.server.backend.args import ConfigArgs, default_args
from ktransformers.util.utils import get_free_ports
from transformers import AutoConfig
+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
+from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
class ArgumentParser:
def __init__(self, cfg):
@@ -135,9 +138,16 @@ class ArgumentParser:
self.cfg.server_ip = args.host
self.cfg.server_port = args.port
self.cfg.user_force_think = args.force_think
-
- model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
- if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" :
+ try:
+ model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+ except:
+ try:
+ model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+ except:
+ raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")
+
+
+ if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallThinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM":
args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim
args.architectures = model_config.architectures[0]
else:
diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py
index d3d6dcf..045f7a1 100644
--- a/ktransformers/server/backend/interfaces/balance_serve.py
+++ b/ktransformers/server/backend/interfaces/balance_serve.py
@@ -24,7 +24,11 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
+from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM
+from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
+from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
@@ -60,6 +64,8 @@ default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
"Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
+ "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
+ "Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
}
@@ -123,15 +129,25 @@ class Engine:
self.sched_client = SchedulerClient(args.sched_port)
self.updates = []
- try:
- config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
- except:
- if args.model_name == "Qwen3Moe":
- config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
- else:
- assert False, f"model {args.model_name} not supported"
+ print(f"args.architectures: {args.architectures}")
+
+ if args.architectures == "Qwen3MoeForCausalLM":
+ config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+ elif args.architectures == "Glm4MoeForCausalLM":
+ config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+ elif args.architectures == "SmallThinkerForCausalLM":
+ config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+ config._attn_implementation = "eager"
+ config.moe_intermediate_size = config.moe_ffn_hidden_size
+ else:
+ try:
+ config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+ except:
+ raise ValueError(f"Model {args.architectures} not supported. Please check your model directory or model name.")
+
+
+
-
self.gen_queue = generated_token_queue
with torch.device("meta"):
@@ -147,6 +163,13 @@ class Engine:
self.model = KQwen2MoeForCausalLM(config, self.cache)
else:
self.model = KQwen3MoeForCausalLM(config, self.cache)
+ elif config.architectures[0] == "SmallThinkerForCausalLM":
+ self.cache = KGQACache(config, self.args.page_size)
+ self.model = KSmallThinkerForCausalLM(config, self.cache)
+ elif config.architectures[0] == "Glm4MoeForCausalLM":
+ self.cache = KGQACache(config, self.args.page_size)
+ self.model = KGlm4MoeForCausalLM(config, self.cache)
+
context = zmq.Context()
@@ -197,7 +220,7 @@ class Engine:
self.block_num = inference_context.k_cache[0].size(1)
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
#@TODO add config
- if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
+ if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM":
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
else:
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
diff --git a/ktransformers/server/balance_serve/inference/model_runner.py b/ktransformers/server/balance_serve/inference/model_runner.py
index 55dfb6d..75fb169 100644
--- a/ktransformers/server/balance_serve/inference/model_runner.py
+++ b/ktransformers/server/balance_serve/inference/model_runner.py
@@ -29,6 +29,8 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
+from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM
+from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.settings import sched_ext
@@ -53,7 +55,7 @@ def generate_cuda_graphs(chunk_size: int) -> list:
class ModelRunner:
"""A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
- model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM
+ model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM
input: ForwardBatchInput | list[ForwardBatchInput]
output: ForwardBatchOutput
@@ -93,7 +95,7 @@ class ModelRunner:
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
- elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
+ elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads,
@@ -124,7 +126,7 @@ class ModelRunner:
num_tokens = self.features_buf[i][0].size(0)
print("capturing cuda graph", batch_size, num_tokens)
- if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
+ if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens)
self.bsz_tensor_buf[0] = batch_size
diff --git a/ktransformers/server/balance_serve/sched_rpc.py b/ktransformers/server/balance_serve/sched_rpc.py
index 218d1d3..51556f8 100644
--- a/ktransformers/server/balance_serve/sched_rpc.py
+++ b/ktransformers/server/balance_serve/sched_rpc.py
@@ -10,7 +10,7 @@ current_file_path = os.path.abspath(__file__)
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
import pickle
import argparse
-from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe
+from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe, create_sched_settings_glm4moe, create_sched_settings_smallthinker
@@ -213,6 +213,10 @@ if __name__ == '__main__':
settings = create_sched_settings_qwen2moe(main_args)
elif main_args.architectures == "Qwen3MoeForCausalLM":
settings = create_sched_settings_qwen3moe(main_args)
+ elif main_args.architectures == "Glm4MoeForCausalLM":
+ settings = create_sched_settings_glm4moe(main_args)
+ elif main_args.architectures == "SmallThinkerForCausalLM":
+ settings = create_sched_settings_smallthinker(main_args)
else:
settings = create_sched_settings(main_args)
start_server(settings, main_args)
diff --git a/ktransformers/server/balance_serve/settings.py b/ktransformers/server/balance_serve/settings.py
index 40a29bf..0d62df1 100644
--- a/ktransformers/server/balance_serve/settings.py
+++ b/ktransformers/server/balance_serve/settings.py
@@ -12,6 +12,8 @@ import sched_ext
from transformers import AutoConfig
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
+from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
def create_sched_settings(args):
default_sample_options = sched_ext.SampleOptions()
@@ -172,6 +174,110 @@ def create_sched_settings_qwen3moe(args):
settings.auto_derive()
return settings
+def create_sched_settings_glm4moe(args):
+ default_sample_options = sched_ext.SampleOptions()
+ model_name = os.path.basename(os.path.normpath(args.model_dir))
+ input_model_settings = sched_ext.ModelSettings()
+ input_model_settings.model_path = args.model_dir
+ input_model_settings.params_count = int(0)
+ model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+ input_model_settings.layer_count = model_config.num_hidden_layers
+ input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"]
+ input_model_settings.k_head_dim = 128
+ input_model_settings.bytes_per_params = 2
+ input_model_settings.bytes_per_kv_cache_element = 2
+ settings = sched_ext.Settings()
+ settings.model_name = model_name
+ settings.quant_type = "BF16"
+ settings.model_settings = input_model_settings
+ settings.page_size = args.page_size
+ settings.gpu_device_count = 1 # tp
+ settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+ # settings.gpu_memory_size = args.cache_lens*576*2
+ settings.gpu_memory_size = args.gpu_memory_size
+ settings.memory_utilization_percentage = args.utilization_percentage
+ max_batch_size = args.max_batch_size
+ chunk_size = args.chunk_size
+
+ max_decode_batch_size = max_batch_size - 2
+
+ settings.max_batch_size = max_batch_size
+ settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+ settings.sample_options = default_sample_options
+ settings.sched_metrics_port = args.sched_metrics_port
+ settings.gpu_only = args.memory_gpu_only
+ settings.use_self_defined_head_dim = False
+ settings.self_defined_head_dim = 576
+ settings.full_kv_cache_on_each_gpu = True
+ settings.k_cache_on = True
+ settings.v_cache_on = True
+
+ settings.kvc2_root_path = args.kvc2_disk_path
+ settings.kvc2_config_path = args.kvc2_config_dir
+ settings.memory_pool_size_GB = args.cpu_memory_size_GB
+ settings.evict_count = 40
+ settings.kvc2_metrics_port = args.kvc2_metrics_port
+ settings.load_from_disk = False
+ settings.save_to_disk = True
+
+
+ settings.strategy_name = args.sched_strategy
+
+ settings.auto_derive()
+ return settings
+
+def create_sched_settings_smallthinker(args):
+ default_sample_options = sched_ext.SampleOptions()
+ model_name = os.path.basename(os.path.normpath(args.model_dir))
+ input_model_settings = sched_ext.ModelSettings()
+ input_model_settings.model_path = args.model_dir
+ input_model_settings.params_count = int(0)
+ model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+ input_model_settings.layer_count = model_config.num_hidden_layers
+ input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"]
+ input_model_settings.k_head_dim = 128
+ input_model_settings.bytes_per_params = 2
+ input_model_settings.bytes_per_kv_cache_element = 2
+ settings = sched_ext.Settings()
+ settings.model_name = model_name
+ settings.quant_type = "BF16"
+ settings.model_settings = input_model_settings
+ settings.page_size = args.page_size
+ settings.gpu_device_count = 1 # tp
+ settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+ # settings.gpu_memory_size = args.cache_lens*576*2
+ settings.gpu_memory_size = args.gpu_memory_size
+ settings.memory_utilization_percentage = args.utilization_percentage
+ max_batch_size = args.max_batch_size
+ chunk_size = args.chunk_size
+
+ max_decode_batch_size = max_batch_size - 2
+
+ settings.max_batch_size = max_batch_size
+ settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+ settings.sample_options = default_sample_options
+ settings.sched_metrics_port = args.sched_metrics_port
+ settings.gpu_only = args.memory_gpu_only
+ settings.use_self_defined_head_dim = False
+ settings.self_defined_head_dim = 576
+ settings.full_kv_cache_on_each_gpu = True
+ settings.k_cache_on = True
+ settings.v_cache_on = True
+
+ settings.kvc2_root_path = args.kvc2_disk_path
+ settings.kvc2_config_path = args.kvc2_config_dir
+ settings.memory_pool_size_GB = args.cpu_memory_size_GB
+ settings.evict_count = 40
+ settings.kvc2_metrics_port = args.kvc2_metrics_port
+ settings.load_from_disk = False
+ settings.save_to_disk = True
+
+
+ settings.strategy_name = args.sched_strategy
+
+ settings.auto_derive()
+ return settings
+
diff --git a/ktransformers/tests/test_speed.py b/ktransformers/tests/test_speed.py
index 6f435b4..41848c1 100644
--- a/ktransformers/tests/test_speed.py
+++ b/ktransformers/tests/test_speed.py
@@ -149,7 +149,7 @@ if __name__ == "__main__":
parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
- parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
+ parser.add_argument("--max_tokens", type=int, default=500, help="max decode tokens")
args = parser.parse_args()
SERVER_URL = args.api_url
@@ -161,5 +161,7 @@ if __name__ == "__main__":
prompt = ktansformer_prompt1024 * 2
elif args.prompt_lens == 4096:
prompt = ktansformer_prompt1024 * 4
+
+
asyncio.run(main(args.concurrent, prompt, max_tokens, model))
diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py
index 003f93c..ee08e47 100644
--- a/ktransformers/util/custom_loader.py
+++ b/ktransformers/util/custom_loader.py
@@ -138,8 +138,12 @@ class SafeTensorLoader(ModelLoader):
base_key = key # e.g. "model.layers.3.mlp.experts"
experts_count = 0
+ key_no_proj = False
+ if self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
+ key_no_proj = True
+
# First, count how many experts we have by checking for expert 0's up_proj
- while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight"):
+ while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
experts_count += 1
if experts_count == 0:
@@ -152,9 +156,15 @@ class SafeTensorLoader(ModelLoader):
# Load all expert weights
for expert_id in range(experts_count):
- up_key = f"{base_key}.{expert_id}.up_proj.weight"
- gate_key = f"{base_key}.{expert_id}.gate_proj.weight"
- down_key = f"{base_key}.{expert_id}.down_proj.weight"
+
+ if key_no_proj:
+ up_key = f"{base_key}.{expert_id}.up.weight"
+ gate_key = f"{base_key}.{expert_id}.gate.weight"
+ down_key = f"{base_key}.{expert_id}.down.weight"
+ else:
+ up_key = f"{base_key}.{expert_id}.up_proj.weight"
+ gate_key = f"{base_key}.{expert_id}.gate_proj.weight"
+ down_key = f"{base_key}.{expert_id}.down_proj.weight"
up_tensor = self.load_tensor(up_key, device)
gate_tensor = self.load_tensor(gate_key, device)
diff --git a/pyproject.toml b/pyproject.toml
index 9502c55..514516c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ dynamic = ["version"]
dependencies = [
"torch >= 2.3.0",
- "transformers == 4.51.3",
+ "transformers == 4.53.3",
"fastapi >= 0.111.0",
"uvicorn >= 0.30.1",
"langchain >= 0.2.0",
diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt
index 25afaef..8743136 100644
--- a/requirements-local_chat.txt
+++ b/requirements-local_chat.txt
@@ -1,5 +1,5 @@
fire
-transformers==4.51.3
+transformers==4.53.3
numpy
torch>=2.3.0
packaging