From d03d92ba53e6b23c227b667c3e81ab4614e31f3f Mon Sep 17 00:00:00 2001 From: djw Date: Fri, 25 Jul 2025 17:22:20 +0000 Subject: [PATCH] support glm4moe --- .gitignore | 6 +- README.md | 11 +- csrc/balance_serve/CMakeLists.txt | 8 +- csrc/balance_serve/sched/model_config.h | 6 +- csrc/ktransformers_ext/ext_bindings.cpp | 8 +- csrc/ktransformers_ext/operators/amx/moe.hpp | 50 +- .../operators/llamafile/moe.cpp | 31 +- .../operators/llamafile/moe.h | 5 +- doc/en/SmallThinker_and_Glm4moe.md | 76 ++ ktransformers/configs/config.yaml | 14 +- ktransformers/ktransformers | 1 + .../models/configuration_glm4_moe.py | 242 +++++++ .../models/custom_modeling_glm4_moe.py | 124 ++++ ktransformers/models/modeling_glm4_moe.py | 649 ++++++++++++++++++ ktransformers/operators/RoPE.py | 92 ++- .../operators/balance_serve_attention.py | 191 +++++- ktransformers/operators/experts.py | 397 +++++++++++ ktransformers/operators/gate.py | 3 +- ktransformers/operators/layernorm.py | 90 +++ ktransformers/operators/linear.py | 9 +- ktransformers/operators/mlp.py | 33 + .../optimize_rules/Glm4Moe-serve.yaml | 90 +++ ktransformers/server/args.py | 16 +- .../backend/interfaces/balance_serve.py | 41 +- .../balance_serve/inference/model_runner.py | 8 +- .../server/balance_serve/sched_rpc.py | 6 +- .../server/balance_serve/settings.py | 106 +++ ktransformers/tests/test_speed.py | 4 +- ktransformers/util/custom_loader.py | 18 +- pyproject.toml | 2 +- requirements-local_chat.txt | 2 +- 31 files changed, 2265 insertions(+), 74 deletions(-) create mode 100644 doc/en/SmallThinker_and_Glm4moe.md create mode 120000 ktransformers/ktransformers create mode 100644 ktransformers/models/configuration_glm4_moe.py create mode 100644 ktransformers/models/custom_modeling_glm4_moe.py create mode 100644 ktransformers/models/modeling_glm4_moe.py create mode 100644 ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml diff --git a/.gitignore b/.gitignore index 38bb53c..7cb23b1 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,8 @@ ktransformers/tests/chat_txt.txt mmlu_result* ktransformers/ktransformers_ext/cuda_musa/ test_prompt.txt -csrc/demo \ No newline at end of file +csrc/demo +build* +CMakeFiles/ +kvc2/ +sched/ \ No newline at end of file diff --git a/README.md b/README.md index 4f630f3..117a5e0 100644 --- a/README.md +++ b/README.md @@ -23,19 +23,14 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

+* **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md)) * **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md)) - * **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse. - * **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)). - * **Apr 29, 2025**: Support AMX-Int8、 AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md)) https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2 - - - * **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)). * **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)). @@ -65,7 +60,7 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM([Tutorial](./doc/en/DeepseekR1_V3_tutorial.md)). - + - Prefill Speed (tokens/s): - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only) - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**. @@ -131,7 +126,6 @@ we have already supported vendors: - Kunpeng - AMD - ### 📥 Installation To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html). @@ -201,3 +195,4 @@ If you have any questions, feel free to open an issue. Alternatively, you can jo

🙋 FAQ

Some common questions are answered in the [FAQ](doc/en/FAQ.md). + diff --git a/csrc/balance_serve/CMakeLists.txt b/csrc/balance_serve/CMakeLists.txt index 4a78161..08fc430 100644 --- a/csrc/balance_serve/CMakeLists.txt +++ b/csrc/balance_serve/CMakeLists.txt @@ -10,10 +10,10 @@ message(STATUS "Using compiler: ${CMAKE_CXX_COMPILER}") project(balance_serve VERSION 0.1.0) set(CMAKE_CXX_STANDARD 20) -# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC") -# set(CMAKE_BUILD_TYPE "Debug") -set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC") -set(CMAKE_BUILD_TYPE "Release") +set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC") +set(CMAKE_BUILD_TYPE "Debug") +# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC") +# set(CMAKE_BUILD_TYPE "Release") if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI) diff --git a/csrc/balance_serve/sched/model_config.h b/csrc/balance_serve/sched/model_config.h index e7512c4..78fc8dc 100644 --- a/csrc/balance_serve/sched/model_config.h +++ b/csrc/balance_serve/sched/model_config.h @@ -15,16 +15,14 @@ using ModelName = std::string; class ModelConfig { public: DimSize hidden_size; - DimSize intermediate_size; size_t max_position_embeddings; - std::string model_type; size_t num_attention_heads; size_t num_hidden_layers; size_t num_key_value_heads; size_t vocab_size; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, - max_position_embeddings, model_type, + NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, + max_position_embeddings, num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size); diff --git a/csrc/ktransformers_ext/ext_bindings.cpp b/csrc/ktransformers_ext/ext_bindings.cpp index f0aeaa5..a6a717b 100644 --- a/csrc/ktransformers_ext/ext_bindings.cpp +++ b/csrc/ktransformers_ext/ext_bindings.cpp @@ -683,12 +683,12 @@ PYBIND11_MODULE(cpuinfer_ext, m) { py::class_(moe_module, "MOEConfig") .def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, - int group_max_len, intptr_t gate_proj, + int group_max_len, bool use_silu, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) { return MOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, stride, group_min_len, - group_max_len, (void *)gate_proj, (void *)up_proj, + group_max_len, use_silu, (void *)gate_proj, (void *)up_proj, (void *)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type); @@ -703,11 +703,11 @@ PYBIND11_MODULE(cpuinfer_ext, m) { py::class_(moe_module, "AMX_MOEConfig") .def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, - int max_len, intptr_t gate_proj, + int max_len, bool use_silu, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj) { return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, - max_len, (void *)gate_proj, + max_len, use_silu, (void *)gate_proj, (void *)up_proj, (void *)down_proj); })); diff --git a/csrc/ktransformers_ext/operators/amx/moe.hpp b/csrc/ktransformers_ext/operators/amx/moe.hpp index 81df642..662189f 100644 --- a/csrc/ktransformers_ext/operators/amx/moe.hpp +++ b/csrc/ktransformers_ext/operators/amx/moe.hpp @@ -69,22 +69,29 @@ static inline __m512 act_fn(__m512 gate_val, __m512 up_val) { return _mm512_mul_ps(act_val, up_val); } +static inline __m512 relu_act_fn(__m512 gate_val, __m512 up_val) { + __m512 zero_vec = _mm512_setzero_ps(); + __m512 act_val = _mm512_max_ps(zero_vec, gate_val); + return _mm512_mul_ps(act_val, up_val); +} + struct AMX_MOEConfig { int expert_num; int routed_expert_num; int hidden_size; int intermediate_size; int max_len; + bool use_silu; void *gate_proj; void *up_proj; void *down_proj; AMX_MOEConfig() {} - AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len, + AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len, bool use_silu, void *gate_proj, void *up_proj, void *down_proj) : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), - intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj), + intermediate_size(intermediate_size), max_len(max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj) {} }; @@ -336,18 +343,35 @@ public: gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth); up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth); auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); - for (int i = 0; i < m_local_num_[expert_idx]; i++) { - ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; - ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; - for (int j = n_start; j < n_end; j += 32) { - __m512 gate_val0, gate_val1, up_val0, up_val1; - avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1); - avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1); - __m512 result0 = act_fn(gate_val0, up_val0); - __m512 result1 = act_fn(gate_val1, up_val1); - avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j)); - } + if (config_.use_silu) { + for (int i = 0; i < m_local_num_[expert_idx]; i++) { + ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; + ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; + for (int j = n_start; j < n_end; j += 32) { + __m512 gate_val0, gate_val1, up_val0, up_val1; + avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1); + avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1); + __m512 result0 = act_fn(gate_val0, up_val0); + __m512 result1 = act_fn(gate_val1, up_val1); + avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j)); + } + } } + else { + for (int i = 0; i < m_local_num_[expert_idx]; i++) { + ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; + ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; + for (int j = n_start; j < n_end; j += 32) { + __m512 gate_val0, gate_val1, up_val0, up_val1; + avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1); + avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1); + __m512 result0 = relu_act_fn(gate_val0, up_val0); + __m512 result1 = relu_act_fn(gate_val1, up_val1); + avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j)); + } + } + } + }, nullptr); backend->do_work_stealing_job( diff --git a/csrc/ktransformers_ext/operators/llamafile/moe.cpp b/csrc/ktransformers_ext/operators/llamafile/moe.cpp index cd42691..86e55a2 100644 --- a/csrc/ktransformers_ext/operators/llamafile/moe.cpp +++ b/csrc/ktransformers_ext/operators/llamafile/moe.cpp @@ -10,6 +10,7 @@ #include "moe.h" #include #include +#include #ifdef USE_NUMA #include @@ -134,6 +135,14 @@ static float act_fn(float x) { return x / (1.0f + expf(-x)); } +static float act_fn_relu(float x) { + if(x > 0.0){ + return x; + } else { + return 0.0; + } +} + void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) { const void* gate_input_ptr; const void* up_input_ptr; @@ -182,8 +191,16 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); - for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { - s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i]; + if(config_.use_silu){ + // use silu as act fn + for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { + s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i]; + } + } else { + // use relu as act fn + for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { + s_intermediate_fp32_[expert_idx][i] = act_fn_relu(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i]; + } } if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) { float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride; @@ -304,8 +321,14 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = 0; i < m_local_num_[expert_idx]; i++) { - for (int j = ith * stride; j < (ith + 1) * stride; j++) { - m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j]; + if(config_.use_silu){ + for (int j = ith * stride; j < (ith + 1) * stride; j++) { + m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j]; + } + } else { + for (int j = ith * stride; j < (ith + 1) * stride; j++) { + m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn_relu(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j]; + } } float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride; void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type); diff --git a/csrc/ktransformers_ext/operators/llamafile/moe.h b/csrc/ktransformers_ext/operators/llamafile/moe.h index 28d7ad3..b568b51 100644 --- a/csrc/ktransformers_ext/operators/llamafile/moe.h +++ b/csrc/ktransformers_ext/operators/llamafile/moe.h @@ -32,6 +32,7 @@ struct MOEConfig { int stride; int group_min_len; int group_max_len; + bool use_silu; void* gate_proj; void* up_proj; void* down_proj; @@ -42,8 +43,8 @@ struct MOEConfig { MOEConfig() {} - MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type) - : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {} + MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, bool use_silu, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type) + : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {} }; class MOE { diff --git a/doc/en/SmallThinker_and_Glm4moe.md b/doc/en/SmallThinker_and_Glm4moe.md new file mode 100644 index 0000000..611f4ed --- /dev/null +++ b/doc/en/SmallThinker_and_Glm4moe.md @@ -0,0 +1,76 @@ +# GLM-4-MoE Support for KTransformers + +## Introduction + +### Overview +We are excited to announce that **KTransformers now supports GLM-4-MoE**. + +- **GLM-4-MoE 110B (bf16)**: ~11 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~440 GB DRAM. +- **GLM-4-MoE 110B (AMX INT8)**: prefill ~309 TPS / decode ~16 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~220 GB DRAM. + +### Model & Resource Links +- **GLM-4-MoE 110B** + - *(to be announced)* + +## Installation Guide + +### 1. Resource Requirements + +| Model | Precision | Experts | DRAM Needed | GPU Memory Needed\* | TPS (approx.) | +| ------------------------- | --------- | ------- | ----------- | ------------------- | ------------------------------ | +| GLM-4-MoE 110B | bf16 | 128 | \~440 GB | 14 GB | \~11 TPS | +| GLM-4-MoE 110B (AMX INT8) | int8 | 128 | \~220 GB | 14 GB | prefill \~309 TPS / decode \~16 TPS | + +\* Exact GPU memory depends on sequence length, batch size, and kernels used. + +### 2. Prepare Models + +```bash +# Example: download original safetensors (adjust to your paths/repos) +# (Fill in actual repos/filenames yourself) + +# GLM-4-MoE 110B +huggingface-cli download --resume-download placeholder-org/Model-TBA \ + --local-dir ./Model-TBA +```` + +### 3. Install KTransformers + +Follow the official Installation Guide. + +```bash +pip install ktransformers # or from source if you need bleeding-edge features +``` + +### 4. Run GLM-4-MoE 110B Inference Server + +```bash +python ktransformers/server/main.py \ + --port 10110 \ + --model_name Glm4MoeForCausalLM \ + --model_path /abs/path/to/GLM-4-MoE-110B-bf16 \ + --optimize_config_path ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml \ + --max_new_tokens 1024 \ + --cache_lens 32768 \ + --chunk_size 256 \ + --max_batch_size 4 \ + --backend_type balance_serve +``` + +### 5. Access Server + +```bash +curl -X POST http://localhost:10110/v1/chat/completions \ + -H "accept: application/json" \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "user", "content": "hello"} + ], + "model": "GLM-4-MoE-110B", + "temperature": 0.3, + "top_p": 1.0, + "stream": true + }' +``` + diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml index 3bd60f9..4050c8c 100644 --- a/ktransformers/configs/config.yaml +++ b/ktransformers/configs/config.yaml @@ -21,12 +21,12 @@ user: model: # type: transformers - # type: balance_serve - type: ktransformers + type: balance_serve + # type: ktransformers - name: DeepSeek-Coder-V2-Instruct - path: deepseek-ai/DeepSeek-V2-Lite-Chat - gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF + name: SmallThinkerForCausalLM + path: /mnt/data/models/Smallthinker-21B + gguf_path: /mnt/data/models/Smallthinker-21B device: cuda:0 cache_lens: 16384 @@ -67,7 +67,7 @@ attn: page_size: 256 chunk_size: 256 kvc2: - gpu_only: false + gpu_only: true utilization_percentage: 1.0 cpu_memory_size_GB: 500 - disk_path: /mnt/data/kvc \ No newline at end of file + disk_path: /home/wjh/kvc \ No newline at end of file diff --git a/ktransformers/ktransformers b/ktransformers/ktransformers new file mode 120000 index 0000000..598751a --- /dev/null +++ b/ktransformers/ktransformers @@ -0,0 +1 @@ +/home/djw/py311_717/ktransformers/ktransformers \ No newline at end of file diff --git a/ktransformers/models/configuration_glm4_moe.py b/ktransformers/models/configuration_glm4_moe.py new file mode 100644 index 0000000..9906ab0 --- /dev/null +++ b/ktransformers/models/configuration_glm4_moe.py @@ -0,0 +1,242 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm4_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class Glm4MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4MoeModel`]. It is used to instantiate a + Glm4Moe model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of [THUDM/GLM-4-100B-A10B](https://huggingface.co/THUDM/GLM-4-100B-A10B). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151552): + Vocabulary size of the Glm4Moe model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Glm4MoeModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 10944): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 46): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 96): + Number of attention heads for each attention layer in the Transformer encoder. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + The factor of the partial rotary position. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + moe_intermediate_size (`int`, *optional*, defaults to 1408): + Intermediate size of the routed expert. + num_experts_per_tok (`int`, *optional*, defaults to 8): + number of experts per token. + n_shared_experts (`int`, *optional*, defaults to 1): + Number of shared experts. + n_routed_experts (`int`, *optional*, defaults to 128): + Number of routed experts. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + n_group (`int`, *optional*, defaults to 1): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to 1): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + first_k_dense_replace (`int`, *optional*, defaults to 1): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + use_qk_norm (`bool`, *optional*, defaults to `False`): + Whether to use query-key normalization in the attention + ```python + >>> from transformers import Glm4MoeModel, Glm4MoeConfig + + >>> # Initializing a Glm4Moe style configuration + >>> configuration = Glm4MoeConfig() + + >>> # Initializing a model from the GLM-4-MOE-100B-A10B style configuration + >>> model = Glm4MoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Glm4Moe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151552, + hidden_size=4096, + intermediate_size=10944, + num_hidden_layers=46, + num_attention_heads=96, + partial_rotary_factor=0.5, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + moe_intermediate_size=1408, + num_experts_per_tok=8, + n_shared_experts=1, + n_routed_experts=128, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, + norm_topk_prob=True, + use_qk_norm=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.partial_rotary_factor = partial_rotary_factor + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # MoE arguments + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.n_group = n_group + self.topk_group = topk_group + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.routed_scaling_factor = routed_scaling_factor + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.use_qk_norm = use_qk_norm + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Glm4MoeConfig"] \ No newline at end of file diff --git a/ktransformers/models/custom_modeling_glm4_moe.py b/ktransformers/models/custom_modeling_glm4_moe.py new file mode 100644 index 0000000..f33c177 --- /dev/null +++ b/ktransformers/models/custom_modeling_glm4_moe.py @@ -0,0 +1,124 @@ +""" +Date: 2024-11-06 10:05:11 +LastEditors: djw +LastEditTime: 2024-11-13 07:50:51 +""" + +import math +from dataclasses import dataclass +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +from typing import List, Optional, Tuple, Union +import torch +import torch.utils.checkpoint +from torch import nn +from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput +from ktransformers.models.custom_cache import KGQACache +from ktransformers.models.modeling_glm4_moe import Glm4MoeModel, Glm4MoePreTrainedModel +from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig +from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn + +torch.set_grad_enabled(False) +torch.set_default_dtype(torch.bfloat16) +import flashinfer + +class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel): + + cache: KGQACache + use_cuda_graph = False + def __init__( + self, + config: Glm4MoeConfig, + cache, + ): + + super().__init__(config) + self.model = Glm4MoeModel(config) + self.config = config + self.cache = cache + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.attn = [None] * 100 + + def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0): + self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device) + + + def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): + features = [] + for i in range(batch.batch_size): + tokens = batch.minibatch.tokens.contiguous() + feature = ( + self.model.embed_tokens(tokens.to(torch.device('cpu'))) + .to(torch.bfloat16) + .to(device=device) + ) + features.append(feature) + + return features + + + def forward( + self, + batch: ForwardBatchInput | None = None, + features: List[torch.Tensor] | None = None, + bsz_tensors: torch.Tensor | None = None, + num_tokens_tensors: torch.Tensor | None = None, + page_idx: torch.Tensor | None = None, + page_offset: torch.Tensor | None = None, + cuda_graph_idx: int | None = 0 + ) -> ForwardBatchOutput: + current_stream = torch.cuda.current_stream() + + forward_batch_output = ForwardBatchOutput() + + + hidden_states = features[0] + self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0]) + + freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0)) + + + with torch.cuda.stream(current_stream): + residual = torch.zeros_like(hidden_states) + for i, decode_layer in enumerate(self.model.layers): + + hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) + hidden_states = decode_layer.self_attn(hidden_states, self.cache, + freqs_cis, + wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, + position_ids=batch.minibatch.position_ids + ) + + hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) + if i < self.model.config.first_k_dense_replace: + hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors) + else: + hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors, cuda_graph_idx) + # hidden_states = hidden_states.squeeze(0) + + forward_batch_output = ForwardBatchOutput() + with torch.cuda.stream(current_stream): + local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors) + forward_batch_output.logits.append(local_logit) + + return forward_batch_output + + + + def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, + num_q_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + causal: bool, + q_data_type: torch.dtype, + kv_data_type: torch.dtype, + cuda_graph_idx: int = 0 + ): + minibatch = batch.minibatch + self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, + minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type) + \ No newline at end of file diff --git a/ktransformers/models/modeling_glm4_moe.py b/ktransformers/models/modeling_glm4_moe.py new file mode 100644 index 0000000..32727a8 --- /dev/null +++ b/ktransformers/models/modeling_glm4_moe.py @@ -0,0 +1,649 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm4_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +# from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +# from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple +from transformers.utils import auto_docstring, can_return_tuple +# from transformers.utils.generic import check_model_inputs +from .configuration_glm4_moe import Glm4MoeConfig + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + # **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class Glm4MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Glm4MoeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + if self.use_qk_norm: # main diff from Llama + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Glm4MoeMLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Glm4MoeTopkRouter(nn.Module): + def __init__(self, config: Glm4MoeConfig): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + + @torch.no_grad() + def get_topk_indices(self, scores): + scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + +# @use_kernel_forward_from_hub("RMSNorm") +class Glm4MoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Glm4MoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.hidden_size = hidden_size + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Glm4MoeMoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = nn.ModuleList( + [ + Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) + self.gate = Glm4MoeTopkRouter(config) + self.shared_experts = Glm4MoeMLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) + + def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + r""" + CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused + to not have to do a loop here (deepseek has 256 experts soooo yeah). + """ + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) + + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) + + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + + # in original deepseek, the output of the experts are gathered once we leave this module + # thus the moe module is itelsf an IsolatedParallel module + # and all expert are "local" meaning we shard but we don't gather + return final_hidden_states.type(hidden_states.dtype) + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +class Glm4MoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Glm4MoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx) + + if layer_idx >= config.first_k_dense_replace: + self.mlp = Glm4MoeMoE(config) + else: + self.mlp = Glm4MoeMLP(config) + + self.input_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + # **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Glm4MoePreTrainedModel(PreTrainedModel): + config: Glm4MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Glm4MoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_static_cache = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Glm4MoeDecoderLayer, + "attentions": Glm4MoeAttention, + } + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Glm4MoeRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Glm4MoeTopkRouter): + module.weight.data.normal_(mean=0.0, std=std) + + +class Glm4MoeRotaryEmbedding(nn.Module): + def __init__(self, config: Glm4MoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Glm4MoeModel(Glm4MoePreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"] + + def __init__(self, config: Glm4MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Glm4MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Glm4MoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + # **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Glm4MoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + # **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Glm4MoeForCausalLM + + >>> model = Glm4MoeForCausalLM.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + # **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["Glm4MoePreTrainedModel", "Glm4MoeModel", "Glm4MoeForCausalLM"] \ No newline at end of file diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index 85d6556..968c7b9 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -26,6 +26,8 @@ from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader from ktransformers.util.utils import InferenceState from transformers.configuration_utils import PretrainedConfig +from ktransformers.models.modeling_smallthinker import SmallthinkerRotaryEmbedding +from ktransformers.models.modeling_glm4_moe import Glm4MoeRotaryEmbedding import torch # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe @@ -437,4 +439,92 @@ class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): def load(self): self.orig_module.__init__( self.orig_module.config - ) \ No newline at end of file + ) + + +class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + # device: str = "cuda", + generate_device: str = "cuda", + prefill_device: str = "cuda", + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs + ) + self.orig_module.__init__( + config + ) + self.generate_device = generate_device + self.prefill_device = prefill_device + + def load(self): + self.orig_module.__init__( + self.orig_module.config, + device = self.generate_device, + ) + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + # device: str = "cuda", + generate_device: str = "cuda", + prefill_device: str = "cuda", + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs + ) + self.orig_module.__init__( + config + ) + self.generate_device = generate_device + self.prefill_device = prefill_device + + def load(self): + self.orig_module.__init__( + self.orig_module.config, + device = self.generate_device, + ) + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + # print(inv_freq_expanded.device) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) \ No newline at end of file diff --git a/ktransformers/operators/balance_serve_attention.py b/ktransformers/operators/balance_serve_attention.py index 51695f3..d493329 100644 --- a/ktransformers/operators/balance_serve_attention.py +++ b/ktransformers/operators/balance_serve_attention.py @@ -9,6 +9,8 @@ from torch import nn from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention +from ktransformers.models.modeling_smallthinker import SmallthinkerAttention +from ktransformers.models.modeling_glm4_moe import Glm4MoeAttention from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader @@ -454,4 +456,191 @@ class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention): attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output, batch_num_tokens_tensors) final_attention_output = torch.cat((final_attention_output, attn_output), dim=0) - return final_attention_output \ No newline at end of file + return final_attention_output + +class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + + def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def forward(self, + hidden_states: torch.Tensor, + kv_cache: KGQACache, + freqs_cis: torch.Tensor, + wrapper: flashInferAttn, + bsz_tensors: torch.Tensor, + position_ids: torch.Tensor = None, + ): + + if self.use_qk_norm: + raise NotImplementedError("use_qk_norm is not implemented yet") + + q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states, bsz_tensors) + key_states = self.k_proj(hidden_states, bsz_tensors) + value_states = self.v_proj(hidden_states, bsz_tensors) + + query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim) + key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) + + # cos, sin = freqs_cis + """ + print(query_states.shape) + print(key_states.shape) + print(cos.shape) + print(sin.shape) + """ + if freqs_cis: + cos, sin = freqs_cis + query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2) + + + + query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim) + key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) + + k_cache = kv_cache.get_k_cache(self.layer_idx) + v_cache = kv_cache.get_v_cache(self.layer_idx) + + + attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) + + + attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors) + + return attn_output + + + + +class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + + def apply_rotary_pos_emb( + self, + q: torch.Tensor, + k: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + unsqueeze_dim=2 + ) -> Tuple[torch.Tensor, torch.Tensor]: + + # Keep half or full tensor for later concatenation + cos = freqs_cis[0] + sin = freqs_cis[1] + rotary_dim = cos.shape[-1] + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + def forward(self, + hidden_states: torch.Tensor, + kv_cache: KGQACache, + freqs_cis: torch.Tensor, + wrapper: flashInferAttn, + bsz_tensors: torch.Tensor, + position_ids: torch.Tensor = None, + ): + + q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states, bsz_tensors) + key_states = self.k_proj(hidden_states, bsz_tensors) + value_states = self.v_proj(hidden_states, bsz_tensors) + + + if self.use_qk_norm: + query_states = self.q_norm(query_states, bsz_tensors) + key_states = self.k_norm(key_states, bsz_tensors) + + + query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim) + key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim) + + # cos, sin = freqs_cis + """ + print(query_states.shape) + print(key_states.shape) + print(cos.shape) + print(sin.shape) + """ + if freqs_cis is not None: + query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis) + + + + query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim) + key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim) + + k_cache = kv_cache.get_k_cache(self.layer_idx) + v_cache = kv_cache.get_v_cache(self.layer_idx) + + + attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) + + + attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors) + + return attn_output \ No newline at end of file diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 7a40168..f4131d1 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -194,6 +194,7 @@ class KExpertsCPU(KExpertsBase): 64, 10, 1024, + self.config.hidden_act == 'silu', gate_ptr, up_ptr, down_ptr, @@ -214,6 +215,7 @@ class KExpertsCPU(KExpertsBase): self.config.hidden_size, self.config.moe_intermediate_size, max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size, + self.config.hidden_act == 'silu', gate_ptr, up_ptr, down_ptr, @@ -232,6 +234,7 @@ class KExpertsCPU(KExpertsBase): self.config.hidden_size, self.config.moe_intermediate_size, max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size, + self.config.hidden_act == 'silu', gate_ptr, up_ptr, down_ptr, @@ -729,6 +732,8 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock +from ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock +from ktransformers.models.modeling_glm4_moe import Glm4MoeMoE class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock): @@ -1248,6 +1253,12 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase): **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + + if prefill_op == 'None': + prefill_op = None + if generate_op == 'None': + generate_op = None + if generate_op is not None: self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) else: @@ -1307,6 +1318,152 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase): else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") + +class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase): + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + # device: str = "cuda", + prefill_device:str = "cuda", + prefill_op: str | None = "KExpertsTorch", + generate_device: str = "cpu", + generate_op: str | None = "KExpertsCPU", + **kwargs): + + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + if generate_op is not None: + self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) + else: + self.generate_experts = None + if prefill_op is not None: + self.prefill_experts = None + self.gpu_mlp_type = prefill_op + self.cpu_mlp_type = generate_op + self.mode = InferenceState.UNLOAD + + def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True): + # TODO support w as input + if not mode: mode = InferenceState.GENERATE + if mode == InferenceState.GENERATE: + # self.prefill_experts.unload() + self.generate_experts.load(w, warmup=warmup) + self.device = self.generate_experts.device + self.mode = mode + elif mode == InferenceState.PREFILL: + self.generate_experts.unload() + self.prefill_experts.load(w, warmup=warmup) + self.device = self.prefill_experts.device + self.mode = mode + elif mode == InferenceState.UNLOAD: + self.unload() + self.mode = mode + self.device = self.generate_experts.device + else: + raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") + + def unload(self): + if self.generate_experts is not None: + self.generate_experts.unload() + if self.prefill_experts is not None: + self.prefill_experts.unload() + self.device = self.generate_experts.device + + def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0): + if self.mode == InferenceState.GENERATE: + assert self.generate_experts is not None, "generate_experts is None" + return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) + elif self.mode == InferenceState.PREFILL: + assert self.prefill_experts is not None, "prefill_experts is None" + return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) + else: + raise ValueError("load or set_inference_mode before forward") + + def set_inference_mode(self, mode: InferenceState): + if mode == InferenceState.GENERATE: + self.load(mode=InferenceState.GENERATE, warmup=False) + elif mode == InferenceState.PREFILL: + self.load(mode=InferenceState.PREFILL, warmup=False) + elif mode == InferenceState.UNLOAD: + self.unload() + else: + raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") + +class KGlm4Experts(BaseInjectedModule, KExpertsBase): + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + # device: str = "cuda", + prefill_device:str = "cuda", + prefill_op: str | None = "KExpertsTorch", + generate_device: str = "cpu", + generate_op: str | None = "KExpertsCPU", + **kwargs): + + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + if generate_op is not None: + self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) + else: + self.generate_experts = None + if prefill_op is not None: + self.prefill_experts = None + self.gpu_mlp_type = prefill_op + self.cpu_mlp_type = generate_op + self.mode = InferenceState.UNLOAD + + def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True): + # TODO support w as input + if not mode: mode = InferenceState.GENERATE + if mode == InferenceState.GENERATE: + # self.prefill_experts.unload() + self.generate_experts.load(w, warmup=warmup) + self.device = self.generate_experts.device + self.mode = mode + elif mode == InferenceState.PREFILL: + self.generate_experts.unload() + self.prefill_experts.load(w, warmup=warmup) + self.device = self.prefill_experts.device + self.mode = mode + elif mode == InferenceState.UNLOAD: + self.unload() + self.mode = mode + self.device = self.generate_experts.device + else: + raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") + + def unload(self): + if self.generate_experts is not None: + self.generate_experts.unload() + if self.prefill_experts is not None: + self.prefill_experts.unload() + self.device = self.generate_experts.device + + def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0): + if self.mode == InferenceState.GENERATE: + assert self.generate_experts is not None, "generate_experts is None" + return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) + elif self.mode == InferenceState.PREFILL: + assert self.prefill_experts is not None, "prefill_experts is None" + return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) + else: + raise ValueError("load or set_inference_mode before forward") + + def set_inference_mode(self, mode: InferenceState): + if mode == InferenceState.GENERATE: + self.load(mode=InferenceState.GENERATE, warmup=False) + elif mode == InferenceState.PREFILL: + self.load(mode=InferenceState.PREFILL, warmup=False) + elif mode == InferenceState.UNLOAD: + self.unload() + else: + raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") + + class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock): def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0): @@ -1507,6 +1664,246 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock): ) return outs + @torch.no_grad() + # TODO may bugs here + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert.forward(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +class KSmallthinkerMoeBlock(BaseInjectedModule, SmallthinkerMoeBlock): + def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0): + + orig_shape = hidden_states.shape + sequence_length = orig_shape[1] + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + if bsz_tensor is None: + if self.enable_early_router: + router_logits = self.primary_router(router_input) + else: + router_logits = self.primary_router(hidden_states) + else: + if self.enable_early_router: + router_logits = self.primary_router(router_input, bsz_tensor) + else: + router_logits = self.primary_router(hidden_states, bsz_tensor) + + router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1) + + + if router_logits.device.type == "xpu": + # TODO: support self.moe_primary_router_apply_softmax False case + from ipex_llm.transformers.models.common import moe_softmax_topk + selected_experts, routing_weights = moe_softmax_topk( + router_logits.half(), self.top_k, self.norm_topk_prob + ) + else: + if self.moe_primary_router_apply_softmax: + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + else: + routing_weights = F.sigmoid(router_logits) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + # only for generate phase + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) + # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) + # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + + y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) + + # y += y_ + y.resize_(*orig_shape) + return y + + # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) + # y_ = ( + # F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + # ) + + + if isinstance(self.experts, KExpertsBase): + y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) + elif hidden_states.size(0) > 10: + # TODO may bugs here + y = ( + self.moe_infer(hidden_states, selected_experts, routing_weights) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + else: + # TODO may bugs here + y = ( + self.moe_infer_simple(hidden_states, selected_experts, routing_weights) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + # y += y_ + return y + + @torch.no_grad() + def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: + outs = torch.empty_like(x) + outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer_simple( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ) -> torch.Tensor: + """ + x: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + """ + outs = torch.zeros_like(x) + for token_idx in range(topk_ids.size(0)): + for expert_idx in range(topk_ids.size(1)): + expert = self.experts[topk_ids[token_idx, expert_idx]] + outs[token_idx] += ( + expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] + ) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert.forward(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE): + def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0): + + orig_shape = hidden_states.shape + sequence_length = orig_shape[1] + + topk_idx, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + # only for generate phase + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx) + y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0) + # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + + y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) + + y += y_ + y.resize_(*orig_shape) + return y + + y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0) + # y_ = ( + # F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + # ) + + + if isinstance(self.experts, KExpertsBase): + y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) + elif hidden_states.size(0) > 10: + # TODO may bugs here + y = ( + self.moe_infer(hidden_states, topk_idx, topk_weight) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + else: + # TODO may bugs here + y = ( + self.moe_infer_simple(hidden_states, topk_idx, topk_weight) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + y += y_ + return y + + @torch.no_grad() + def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: + outs = torch.empty_like(x) + outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer_simple( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ) -> torch.Tensor: + """ + x: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + """ + outs = torch.zeros_like(x) + for token_idx in range(topk_ids.size(0)): + for expert_idx in range(topk_ids.size(1)): + expert = self.experts[topk_ids[token_idx, expert_idx]] + outs[token_idx] += ( + expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] + ) + return outs + @torch.no_grad() # TODO may bugs here def moe_infer(self, x, topk_ids, topk_weight): diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index f5f96c1..b9ccc01 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -212,4 +212,5 @@ class KMoEGateIPEXLLM(KMoEGate): topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias, self.n_group, self.topk_group, self.top_k, self.norm_topk_prob, self.routed_scaling_factor) - return topk_idx, topk_weight.to(x.dtype) \ No newline at end of file + return topk_idx, topk_weight.to(x.dtype) + diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py index 796592c..24bdc81 100644 --- a/ktransformers/operators/layernorm.py +++ b/ktransformers/operators/layernorm.py @@ -28,6 +28,8 @@ import torch.nn as nn from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm +from ktransformers.models.modeling_smallthinker import SmallthinkerRMSNorm +from ktransformers.models.modeling_glm4_moe import Glm4MoeRMSNorm from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader if not torch.xpu.is_available(): @@ -164,6 +166,94 @@ class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule): variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) + +class KSmallthinkerRMSNorm(SmallthinkerRMSNorm, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.hidden_size, + orig_module.variance_epsilon) + + def forward( + self, + x: torch.Tensor, + batch_size_tensor: torch.Tensor = None, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + #return self.forward_native(x, residual) + bsz, hidden_size = x.shape + x = x.view(-1, self.orig_module.hidden_size) + if batch_size_tensor is None: + return self.forward_native(x) + if residual is not None: + fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon) + #residual = x + residual + #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon) + return x, residual + # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous()) + out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) + out = out.view(bsz, hidden_size) + return out + + def forward_native( + self, hidden_states + ): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class KGlm4MoeRMSNorm(Glm4MoeRMSNorm, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.hidden_size, + orig_module.variance_epsilon) + + def forward( + self, + x: torch.Tensor, + batch_size_tensor: torch.Tensor = None, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + #return self.forward_native(x, residual) + bsz, hidden_size = x.shape + x = x.view(-1, self.orig_module.hidden_size) + if batch_size_tensor is None: + return self.forward_native(x) + if residual is not None: + fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon) + #residual = x + residual + #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon) + return x, residual + # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous()) + out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) + out = out.view(bsz, hidden_size) + return out + + def forward_native( + self, hidden_states + ): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule): def __init__(self, diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 654c9f9..e232f83 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -88,10 +88,17 @@ class KLinearBase(ABC): if isinstance(self.gguf_loader, SafeTensorLoader): # using safetensor_loader tensor = self.gguf_loader.load_tensor(key+'.weight') + try: + bias = self.gguf_loader.load_tensor(key+'.bias') + except: + bias = None if self.gguf_loader.has_tensor(key+'.weight_scale_inv'): weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv') return nn.Parameter(tensor), nn.Parameter(weight_scale_inv) - return nn.Parameter(tensor) + if bias is not None: + return nn.Parameter(tensor), nn.Parameter(bias) + else: + return nn.Parameter(tensor) elif self.gguf_loader.has_tensor(key + ".weight") or "kv_b_proj" in key: if key + ".bias" in self.gguf_loader.tensor_file_map: diff --git a/ktransformers/operators/mlp.py b/ktransformers/operators/mlp.py index 77d7d05..6d3e812 100644 --- a/ktransformers/operators/mlp.py +++ b/ktransformers/operators/mlp.py @@ -5,6 +5,8 @@ from transformers import PretrainedConfig import torch.nn as nn from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP +from ktransformers.models.modeling_smallthinker import SmallthinkerDenseMlpBlock +from ktransformers.models.modeling_glm4_moe import Glm4MoeMLP class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule): def __init__(self, key: str, @@ -32,6 +34,37 @@ class KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.intermediate_size) + def forward(self, x, bsz_tensor): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor) + return down_proj + + +class KSmallthinkerDenseMlpBlock(SmallthinkerDenseMlpBlock, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config) + def forward(self, x, bsz_tensor): + down_proj = self.down(nn.functional.relu(self.gate(x, bsz_tensor)) * self.up(x, bsz_tensor), bsz_tensor) + return down_proj + +class KGlm4MoeMLP(Glm4MoeMLP, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size) def forward(self, x, bsz_tensor): down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor) return down_proj \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml b/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml new file mode 100644 index 0000000..56345df --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml @@ -0,0 +1,90 @@ +- match: + class: ktransformers.models.modeling_glm4_moe.Glm4MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.KGlm4MoeRotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" + +# - match: +# name: "^model\\.layers\\..*$" # regular expression +# class: torch.nn.Linear # only match modules matching name and class simultaneously +# replace: +# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types +# kwargs: +# generate_device: "cuda" +# prefill_device: "cuda" +# generate_op: "VLinearMarlin" +# prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_glm4_moe.Glm4MoeMoE + replace: + class: ktransformers.operators.experts.KGlm4MoeMoE + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KGlm4Experts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: None + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.balance_serve_attention.KGlm4MoeAttention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + class: ktransformers.models.modeling_glm4_moe.Glm4MoeRMSNorm + replace: + class: ktransformers.operators.layernorm.KGlm4MoeRMSNorm + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_glm4_moe.Glm4MoeMLP + replace: + class: ktransformers.operators.mlp.KGlm4MoeMLP + kwargs: + generate_device: "cuda" + prefill_device: "cuda" \ No newline at end of file diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index 748bd47..ecf1f84 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -2,6 +2,9 @@ import argparse from ktransformers.server.backend.args import ConfigArgs, default_args from ktransformers.util.utils import get_free_ports from transformers import AutoConfig +from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig +from ktransformers.models.configuration_smallthinker import SmallthinkerConfig +from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig class ArgumentParser: def __init__(self, cfg): @@ -135,9 +138,16 @@ class ArgumentParser: self.cfg.server_ip = args.host self.cfg.server_port = args.port self.cfg.user_force_think = args.force_think - - model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) - if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" : + try: + model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + except: + try: + model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) + except: + raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.") + + + if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallThinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM": args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim args.architectures = model_config.architectures[0] else: diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py index d3d6dcf..045f7a1 100644 --- a/ktransformers/server/backend/interfaces/balance_serve.py +++ b/ktransformers/server/backend/interfaces/balance_serve.py @@ -24,7 +24,11 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM +from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM +from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig +from ktransformers.models.configuration_smallthinker import SmallthinkerConfig +from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig from ktransformers.server.balance_serve.inference.model_runner import ModelRunner from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions from ktransformers.server.balance_serve.inference.query_manager import QueryManager @@ -60,6 +64,8 @@ default_optimize_rules = { "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml", "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml", + "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml", + "Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml", } @@ -123,15 +129,25 @@ class Engine: self.sched_client = SchedulerClient(args.sched_port) self.updates = [] - try: - config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) - except: - if args.model_name == "Qwen3Moe": - config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) - else: - assert False, f"model {args.model_name} not supported" + print(f"args.architectures: {args.architectures}") + + if args.architectures == "Qwen3MoeForCausalLM": + config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) + elif args.architectures == "Glm4MoeForCausalLM": + config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) + elif args.architectures == "SmallThinkerForCausalLM": + config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True) + config._attn_implementation = "eager" + config.moe_intermediate_size = config.moe_ffn_hidden_size + else: + try: + config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + except: + raise ValueError(f"Model {args.architectures} not supported. Please check your model directory or model name.") + + + - self.gen_queue = generated_token_queue with torch.device("meta"): @@ -147,6 +163,13 @@ class Engine: self.model = KQwen2MoeForCausalLM(config, self.cache) else: self.model = KQwen3MoeForCausalLM(config, self.cache) + elif config.architectures[0] == "SmallThinkerForCausalLM": + self.cache = KGQACache(config, self.args.page_size) + self.model = KSmallThinkerForCausalLM(config, self.cache) + elif config.architectures[0] == "Glm4MoeForCausalLM": + self.cache = KGQACache(config, self.args.page_size) + self.model = KGlm4MoeForCausalLM(config, self.cache) + context = zmq.Context() @@ -197,7 +220,7 @@ class Engine: self.block_num = inference_context.k_cache[0].size(1) self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num) #@TODO add config - if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM": + if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM": self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num) else: self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num) diff --git a/ktransformers/server/balance_serve/inference/model_runner.py b/ktransformers/server/balance_serve/inference/model_runner.py index 55dfb6d..75fb169 100644 --- a/ktransformers/server/balance_serve/inference/model_runner.py +++ b/ktransformers/server/balance_serve/inference/model_runner.py @@ -29,6 +29,8 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM +from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM +from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM from ktransformers.server.balance_serve.inference.query_manager import QueryManager from ktransformers.server.balance_serve.settings import sched_ext @@ -53,7 +55,7 @@ def generate_cuda_graphs(chunk_size: int) -> list: class ModelRunner: """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile.""" - model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM + model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM input: ForwardBatchInput | list[ForwardBatchInput] output: ForwardBatchOutput @@ -93,7 +95,7 @@ class ModelRunner: num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) - elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM): + elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM): self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf, num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads, head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads, @@ -124,7 +126,7 @@ class ModelRunner: num_tokens = self.features_buf[i][0].size(0) print("capturing cuda graph", batch_size, num_tokens) - if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM): + if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM): self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens) self.bsz_tensor_buf[0] = batch_size diff --git a/ktransformers/server/balance_serve/sched_rpc.py b/ktransformers/server/balance_serve/sched_rpc.py index 218d1d3..51556f8 100644 --- a/ktransformers/server/balance_serve/sched_rpc.py +++ b/ktransformers/server/balance_serve/sched_rpc.py @@ -10,7 +10,7 @@ current_file_path = os.path.abspath(__file__) # sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) import pickle import argparse -from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe +from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe, create_sched_settings_glm4moe, create_sched_settings_smallthinker @@ -213,6 +213,10 @@ if __name__ == '__main__': settings = create_sched_settings_qwen2moe(main_args) elif main_args.architectures == "Qwen3MoeForCausalLM": settings = create_sched_settings_qwen3moe(main_args) + elif main_args.architectures == "Glm4MoeForCausalLM": + settings = create_sched_settings_glm4moe(main_args) + elif main_args.architectures == "SmallThinkerForCausalLM": + settings = create_sched_settings_smallthinker(main_args) else: settings = create_sched_settings(main_args) start_server(settings, main_args) diff --git a/ktransformers/server/balance_serve/settings.py b/ktransformers/server/balance_serve/settings.py index 40a29bf..0d62df1 100644 --- a/ktransformers/server/balance_serve/settings.py +++ b/ktransformers/server/balance_serve/settings.py @@ -12,6 +12,8 @@ import sched_ext from transformers import AutoConfig from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig +from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig +from ktransformers.models.configuration_smallthinker import SmallthinkerConfig def create_sched_settings(args): default_sample_options = sched_ext.SampleOptions() @@ -172,6 +174,110 @@ def create_sched_settings_qwen3moe(args): settings.auto_derive() return settings +def create_sched_settings_glm4moe(args): + default_sample_options = sched_ext.SampleOptions() + model_name = os.path.basename(os.path.normpath(args.model_dir)) + input_model_settings = sched_ext.ModelSettings() + input_model_settings.model_path = args.model_dir + input_model_settings.params_count = int(0) + model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) + input_model_settings.layer_count = model_config.num_hidden_layers + input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"] + input_model_settings.k_head_dim = 128 + input_model_settings.bytes_per_params = 2 + input_model_settings.bytes_per_kv_cache_element = 2 + settings = sched_ext.Settings() + settings.model_name = model_name + settings.quant_type = "BF16" + settings.model_settings = input_model_settings + settings.page_size = args.page_size + settings.gpu_device_count = 1 # tp + settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] + # settings.gpu_memory_size = args.cache_lens*576*2 + settings.gpu_memory_size = args.gpu_memory_size + settings.memory_utilization_percentage = args.utilization_percentage + max_batch_size = args.max_batch_size + chunk_size = args.chunk_size + + max_decode_batch_size = max_batch_size - 2 + + settings.max_batch_size = max_batch_size + settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 + settings.sample_options = default_sample_options + settings.sched_metrics_port = args.sched_metrics_port + settings.gpu_only = args.memory_gpu_only + settings.use_self_defined_head_dim = False + settings.self_defined_head_dim = 576 + settings.full_kv_cache_on_each_gpu = True + settings.k_cache_on = True + settings.v_cache_on = True + + settings.kvc2_root_path = args.kvc2_disk_path + settings.kvc2_config_path = args.kvc2_config_dir + settings.memory_pool_size_GB = args.cpu_memory_size_GB + settings.evict_count = 40 + settings.kvc2_metrics_port = args.kvc2_metrics_port + settings.load_from_disk = False + settings.save_to_disk = True + + + settings.strategy_name = args.sched_strategy + + settings.auto_derive() + return settings + +def create_sched_settings_smallthinker(args): + default_sample_options = sched_ext.SampleOptions() + model_name = os.path.basename(os.path.normpath(args.model_dir)) + input_model_settings = sched_ext.ModelSettings() + input_model_settings.model_path = args.model_dir + input_model_settings.params_count = int(0) + model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True) + input_model_settings.layer_count = model_config.num_hidden_layers + input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"] + input_model_settings.k_head_dim = 128 + input_model_settings.bytes_per_params = 2 + input_model_settings.bytes_per_kv_cache_element = 2 + settings = sched_ext.Settings() + settings.model_name = model_name + settings.quant_type = "BF16" + settings.model_settings = input_model_settings + settings.page_size = args.page_size + settings.gpu_device_count = 1 # tp + settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] + # settings.gpu_memory_size = args.cache_lens*576*2 + settings.gpu_memory_size = args.gpu_memory_size + settings.memory_utilization_percentage = args.utilization_percentage + max_batch_size = args.max_batch_size + chunk_size = args.chunk_size + + max_decode_batch_size = max_batch_size - 2 + + settings.max_batch_size = max_batch_size + settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 + settings.sample_options = default_sample_options + settings.sched_metrics_port = args.sched_metrics_port + settings.gpu_only = args.memory_gpu_only + settings.use_self_defined_head_dim = False + settings.self_defined_head_dim = 576 + settings.full_kv_cache_on_each_gpu = True + settings.k_cache_on = True + settings.v_cache_on = True + + settings.kvc2_root_path = args.kvc2_disk_path + settings.kvc2_config_path = args.kvc2_config_dir + settings.memory_pool_size_GB = args.cpu_memory_size_GB + settings.evict_count = 40 + settings.kvc2_metrics_port = args.kvc2_metrics_port + settings.load_from_disk = False + settings.save_to_disk = True + + + settings.strategy_name = args.sched_strategy + + settings.auto_derive() + return settings + diff --git a/ktransformers/tests/test_speed.py b/ktransformers/tests/test_speed.py index 6f435b4..41848c1 100644 --- a/ktransformers/tests/test_speed.py +++ b/ktransformers/tests/test_speed.py @@ -149,7 +149,7 @@ if __name__ == "__main__": parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name") parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") - parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens") + parser.add_argument("--max_tokens", type=int, default=500, help="max decode tokens") args = parser.parse_args() SERVER_URL = args.api_url @@ -161,5 +161,7 @@ if __name__ == "__main__": prompt = ktansformer_prompt1024 * 2 elif args.prompt_lens == 4096: prompt = ktansformer_prompt1024 * 4 + + asyncio.run(main(args.concurrent, prompt, max_tokens, model)) diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py index 003f93c..ee08e47 100644 --- a/ktransformers/util/custom_loader.py +++ b/ktransformers/util/custom_loader.py @@ -138,8 +138,12 @@ class SafeTensorLoader(ModelLoader): base_key = key # e.g. "model.layers.3.mlp.experts" experts_count = 0 + key_no_proj = False + if self.has_tensor(f"{base_key}.{experts_count}.up.weight"): + key_no_proj = True + # First, count how many experts we have by checking for expert 0's up_proj - while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight"): + while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or self.has_tensor(f"{base_key}.{experts_count}.up.weight"): experts_count += 1 if experts_count == 0: @@ -152,9 +156,15 @@ class SafeTensorLoader(ModelLoader): # Load all expert weights for expert_id in range(experts_count): - up_key = f"{base_key}.{expert_id}.up_proj.weight" - gate_key = f"{base_key}.{expert_id}.gate_proj.weight" - down_key = f"{base_key}.{expert_id}.down_proj.weight" + + if key_no_proj: + up_key = f"{base_key}.{expert_id}.up.weight" + gate_key = f"{base_key}.{expert_id}.gate.weight" + down_key = f"{base_key}.{expert_id}.down.weight" + else: + up_key = f"{base_key}.{expert_id}.up_proj.weight" + gate_key = f"{base_key}.{expert_id}.gate_proj.weight" + down_key = f"{base_key}.{expert_id}.down_proj.weight" up_tensor = self.load_tensor(up_key, device) gate_tensor = self.load_tensor(gate_key, device) diff --git a/pyproject.toml b/pyproject.toml index 9502c55..514516c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dynamic = ["version"] dependencies = [ "torch >= 2.3.0", - "transformers == 4.51.3", + "transformers == 4.53.3", "fastapi >= 0.111.0", "uvicorn >= 0.30.1", "langchain >= 0.2.0", diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt index 25afaef..8743136 100644 --- a/requirements-local_chat.txt +++ b/requirements-local_chat.txt @@ -1,5 +1,5 @@ fire -transformers==4.51.3 +transformers==4.53.3 numpy torch>=2.3.0 packaging