diff --git a/.gitignore b/.gitignore
index c33a95d..d45e956 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,4 +18,7 @@ compile_commands.json
ktransformers/server/local_store/
ktransformers/server_test1.db
*.patch
-img/
\ No newline at end of file
+img/
+tmp1.txt
+test_65_300_1536.txt
+test.txt
diff --git a/Makefile b/Makefile
index dbf771d..f8633a9 100644
--- a/Makefile
+++ b/Makefile
@@ -17,5 +17,5 @@ dev_install:
pip install -r requirements-local_chat.txt
echo "Installing ktransformers"
- KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation
+ KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation
echo "Installation completed successfully"
\ No newline at end of file
diff --git a/README.md b/README.md
index eb23bf8..8d92cb7 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+* **Feb 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi GPU and 382GB DRAM, up to 9.44× prefill and 3.03× decoding speedup. The detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md).
* **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
* **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU.
@@ -31,6 +32,43 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
* **Aug 9, 2024**: Support windows native.
🔥 Show Cases
+
+
+
GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM
+
+
+https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
+
+
+
+- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 12GB VRAM and 382GB DRAM.
+    - Prefill Speed (tokens/s):
+        - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → xxx (optimized AMX-based MoE kernel, v3 only) → XXX (selectively using 6 experts, v3 only)
+        - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **9.44× speedup**.
+    - Decode Speed (tokens/s):
+        - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, v3 only)
+        - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.
+    - Upcoming Open Source Release:
+        - AMX optimizations and selective expert activation will be open-sourced in v0.3.
+        - Currently available only in a preview binary distribution, which can be found here.
+
+- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).
+
+
+
+
+
+
+
+- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin).
+- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.
+
+
+
+https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c
+
+
+
1M Context Local Inference on a Desktop with Only 24GB VRAM
@@ -54,30 +92,7 @@ https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12
* **Flexible Sparse Attention Framework**: Offers a flexible block sparse attention framework for CPU offloaded decoding. Compatible with SnapKV, Quest, and InfLLm. Further information is available [here](./doc/en/long_context_introduction.md).
-
-
GPT-4-level Local VSCode Copilot on a Desktop with only 24GB VRAM
-
-https://github.com/user-attachments/assets/0b9fa2da-66f0-48eb-b4b9-f0e1f06f8927
-
-
-
-- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).
-
-
-
-
-
-
-
-- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin).
-- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.
-
-
-
-https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c
-
-
More advanced features will coming soon, so stay tuned!
diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md
new file mode 100644
index 0000000..1b1e6c7
--- /dev/null
+++ b/doc/en/DeepseekR1_V3_tutorial.md
@@ -0,0 +1,64 @@
+## Prerequisites
+We run our best performance tests on:
+- CPU: Intel(R) Xeon(R) Gold 6454S, 1TB DRAM (2 NUMA nodes)
+- GPU: 4090D, 24GB VRAM
+## Benchmark Results
+### V0.2
+#### Settings
+- Model: DeepseekV3-q4km (int4)
+- CPU: Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 sockets, 2 NUMA nodes
+- GPU: 4090D, 24GB VRAM
+- We test after sufficient warm-up.
+#### Memory consumption
+ - Single socket: 382GB DRAM, 12GB VRAM
+ - Dual socket: 1TB DRAM, 12GB VRAM
+
+#### Results
+
+The "6 experts" case is part of the v0.3 preview.
+
+| Prompt<br>(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts) | llama.cpp (8 experts) |
+| --- | --- | --- | --- | --- | --- |
+| Prefill token/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 |
+| Decode token/s | 13.69 | 12.208 | 10.303 | 8.73 | 4.51 |
+
+**The highest speedup reaches up to 3.03× in decoding and 9.44× in prefill.**
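+(From the table above: decode 13.69 / 4.51 ≈ 3.03×; prefill 97.32 / 10.31 ≈ 9.44×.)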
+
+## How to Run
+### V0.2 showcase
+#### Single socket version (32 cores)
+Our local_chat test command is:
+``` shell
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path <your_model_path> --gguf_path <your_gguf_path> --prompt_file <your_prompt_file> --cpu_infer 33 --cache_lens 1536
+```
+`<your_model_path>` can be a local path or an online Hugging Face id such as deepseek-ai/DeepSeek-V3. If the online connection encounters problems, try using the mirror (hf-mirror.com).
+`<your_gguf_path>` can also be online, but as it is large we recommend you download it and quantize the model to what you want.
+The command `numactl -N 1 -m 1` aims to avoid data transfer between NUMA nodes.
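+
+If you are unsure of your machine's NUMA layout, a quick way to inspect it (assuming `numactl` and `lscpu` are installed) is:
+``` shell
+# list NUMA nodes with their CPUs and memory sizes
+numactl --hardware
+# or check the NUMA summary reported by lscpu
+lscpu | grep -i numa
+```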
+#### Dual socket version (64 cores)
+Make sure that before you install (using install.sh or `make dev_install`) you set the env var `USE_NUMA=1` by `export USE_NUMA=1` (if it is already installed, reinstall it with this env var set).
+Our local_chat test command is:
+``` shell
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+export USE_NUMA=1
+make dev_install # or sh ./install.sh
+python ./ktransformers/local_chat.py --model_path <your_model_path> --gguf_path <your_gguf_path> --prompt_file <your_prompt_file> --cpu_infer 65 --cache_lens 1536
+```
+The parameters' meaning is the same, but as we use dual sockets, we set cpu_infer to 65.
+## Some Explanations
+1. From our experiments on DeepSeekV2, DeepSeekV3 and DeepSeekR1, when we slightly decrease the number of activated experts during inference, the output quality doesn't change (within a 1% accuracy drop), but decoding and prefill speed up by about 30%, which is encouraging. So our showcase makes use of this finding, changing the number of activated experts of DeepSeekV3/R1 from 8 to 6.
+2. We also want to make further use of the two NUMA nodes on the Xeon Gold CPU. To avoid the cost of data transfer between nodes, we "copy" the critical matrices onto both nodes, which consumes more memory but accelerates the prefill and decoding process. However, this method uses a lot of memory and is slow when loading weights, so be patient during loading and monitor the memory usage. (We are considering making this method an option.)
+3. The command arg `--cpu_infer 65` specifies how many cores to use (it is okay if it exceeds the physical number, but more is not always better; adjust it slightly below your actual number of cores). See the note below for a quick way to check your core count.
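+
+As a reference for choosing `--cpu_infer`, one quick way to check the socket and core counts on a Linux host (assuming `lscpu` is available) is:
+``` shell
+# print socket count, cores per socket and threads per core
+lscpu | grep -E 'Socket|Core|Thread'
+```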
diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt
index 1ef9823..d9ecd7a 100644
--- a/ktransformers/ktransformers_ext/CMakeLists.txt
+++ b/ktransformers/ktransformers_ext/CMakeLists.txt
@@ -230,3 +230,24 @@ elseif(UNIX)
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
+
+# Define the USE_NUMA option
+option(USE_NUMA "Enable NUMA support" OFF)
+# Check if the USE_NUMA environment variable is set
+if(DEFINED ENV{USE_NUMA})
+ set(USE_NUMA ON)
+endif()
+if (USE_NUMA)
+ message(STATUS "NUMA support is enabled")
+else()
+ message(STATUS "NUMA support is disabled")
+endif()
+
+find_library(NUMA_LIBRARY NAMES numa)
+if (NUMA_LIBRARY AND USE_NUMA)
+ message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support")
+ target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})
+ target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA)
+else()
+ message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support")
+endif()
diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp
index 16693f0..5980ba3 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp
+++ b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp
@@ -10,6 +10,13 @@
#include "backend.h"
+#ifdef USE_NUMA
+#include <numa.h>
+#include <numaif.h>
+
+thread_local int Backend::numa_node = -1;
+#endif
+
thread_local int Backend::thread_local_id = -1;
Backend::Backend(int max_thread_num) {
@@ -74,6 +81,16 @@ void Backend::do_work_stealing_job(int task_num,
}
void Backend::process_tasks(int thread_id) {
+
+ #ifdef USE_NUMA
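+ // First task on this worker thread: assign it a NUMA node (threads are spread
+ // evenly across the configured nodes) and bind its CPU and memory to that node.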
+ if(numa_node == -1){
+ numa_node = thread_id * numa_num_configured_nodes() / thread_num_;
+ struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());
+ numa_bitmask_setbit(mask, numa_node);
+ numa_bind(mask);
+ }
+ #endif
+
if (init_func_ != nullptr) {
init_func_(thread_id);
}
diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.h b/ktransformers/ktransformers_ext/cpu_backend/backend.h
index 80ff7f9..7a95f27 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/backend.h
+++ b/ktransformers/ktransformers_ext/cpu_backend/backend.h
@@ -38,6 +38,9 @@ class Backend {
void do_work_stealing_job(int, std::function<void(int)>,
std::function<void(int)>,
std::function<void(int)>);
+ #ifdef USE_NUMA
+ static thread_local int numa_node;
+ #endif
static thread_local int thread_local_id;
private:
diff --git a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
index a131b1f..35c144f 100644
--- a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
+++ b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
@@ -11,11 +11,41 @@
#include
#include
+#ifdef USE_NUMA
+#include <numa.h>
+#include <numaif.h>
+#endif
+
MOE::MOE(MOEConfig config) {
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
+
+ #ifdef USE_NUMA
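+ // Replicate the expert weight matrices (gate/up/down projections) onto every
+ // NUMA node so each bound worker thread reads them from node-local memory.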
+ int numa_nodes = numa_num_configured_nodes();
+ gate_proj_numa_.resize(numa_nodes);
+ up_proj_numa_.resize(numa_nodes);
+ down_proj_numa_.resize(numa_nodes);
+ size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size;
+ for (int i = 0; i < numa_nodes; i++) {
+ gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i);
+ up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i);
+ down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i);
+ if (!gate_proj_numa_[i]) {
+ std::cout << "Memory allocation failed for gate_proj_numa_ on node " << i << std::endl;
+ }
+ if (!up_proj_numa_[i]) {
+ std::cout << "Memory allocation failed for up_proj_numa_ on node " << i << std::endl;
+ }
+ if (!down_proj_numa_[i]) {
+ std::cout << "Memory allocation failed for down_proj_numa_ on node " << i << std::endl;
+ }
+ memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type));
+ memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type));
+ memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type));
+ }
+ #endif
std::vector<std::pair<void**, uint64_t>> s_mem_requests;
s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});
@@ -74,6 +104,15 @@ MOE::MOE(MOEConfig config) {
MOE::~MOE() {
shared_mem_buffer.dealloc(this);
+
+ #ifdef USE_NUMA
+ int numa_nodes = numa_num_configured_nodes();
+ for (int i = 0; i < numa_nodes; i++) {
+ numa_free(gate_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type));
+ numa_free(up_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type));
+ numa_free(down_proj_numa_[i], config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type));
+ }
+ #endif
}
void MOE::warm_up(Backend* backend) {
@@ -125,10 +164,22 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
int expert_idx = task_id / nth;
uint64_t expert_id = expert_ids[expert_idx];
int ith = task_id % nth;
+
+ #ifdef USE_NUMA
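+ // Use the weight replica that lives on this thread's NUMA node.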
+ void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+ #else
void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+ #endif
+
float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+
+ #ifdef USE_NUMA
+ void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+ #else
void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+ #endif
+
float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
@@ -153,7 +204,13 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
}
for (int expert_idx = 0; expert_idx < k; expert_idx++) {
uint64_t expert_id = expert_ids[expert_idx];
+
+ #ifdef USE_NUMA
+ void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+ #else
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+ #endif
+
float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
@@ -227,11 +284,23 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];
+
+ #ifdef USE_NUMA
+ void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+ #else
void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+ #endif
+
float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
void* up_input_ptr = m_local_up_input_ptr_[expert_idx];
+
+ #ifdef USE_NUMA
+ void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+ #else
void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+ #endif
+
float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
@@ -249,7 +318,13 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx];
+
+ #ifdef USE_NUMA
+ void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+ #else
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+ #endif
+
float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
}, nullptr);
diff --git a/ktransformers/ktransformers_ext/operators/llamafile/moe.h b/ktransformers/ktransformers_ext/operators/llamafile/moe.h
index a1470aa..a39e21d 100644
--- a/ktransformers/ktransformers_ext/operators/llamafile/moe.h
+++ b/ktransformers/ktransformers_ext/operators/llamafile/moe.h
@@ -61,6 +61,12 @@ class MOE {
void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]
+ #ifdef USE_NUMA
+ std::vector<void*> gate_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]
+ std::vector<void*> up_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]
+ std::vector<void*> down_proj_numa_; // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)]
+ #endif
+
float* s_input_fp32_; // [hidden_size]
uint8_t* s_gate_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
uint8_t* s_up_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 827d88f..5f17c21 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -1,24 +1,140 @@
+# """
+# Description :
+# Author : Boxin Zhang, Azure-Tang
+# Version : 0.1.0
+# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+# """
+
+# import asyncio
+# import os
+# import platform
+# import sys
+# project_dir = os.path.dirname(os.path.dirname(__file__))
+# sys.path.insert(0, project_dir)
+# from ktransformers.server.args import ArgumentParser
+
+
+# from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
+# from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
+# from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
+# from ktransformers.models.modeling_llama import LlamaForCausalLM
+# from ktransformers.models.modeling_mixtral import MixtralForCausalLM
+# from ktransformers.server.config.config import Config
+
+# custom_models = {
+# "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
+# "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
+# "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
+# "LlamaForCausalLM": LlamaForCausalLM,
+# "MixtralForCausalLM": MixtralForCausalLM,
+# }
+
+# ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
+# default_optimize_rules = {
+# "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
+# "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
+# "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
+# "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
+# "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
+# }
+
+
+# def local_chat():
+# config = Config()
+# arg_parser = ArgumentParser(config)
+# # 初始化消息
+# arg_parser.parse_args()
+# if config.backend_type == "transformers":
+# from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
+# elif config.backend_type == "exllamav2":
+# from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
+# elif config.backend_type == "ktransformers":
+# from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
+# else:
+# raise NotImplementedError(f"{config.backend_type} not implemented")
+# interface = BackendInterface(config)
+
+# system = platform.system()
+# if system == "Windows":
+# os.system("cls")
+# else:
+# os.system("clear")
+# # add a history chat content
+# his_content = []
+# while True:
+# content = input("Chat: ")
+# if content.startswith('"""'): # prefix """
+# # multi lines input
+# content = content[3:] + "\n"
+# while True:
+# line = input("")
+# if line.endswith('"""'):
+# # end multi lines input
+# line = line[:-3] # suffix """
+# if line:
+# content += line + "\n"
+# break
+# else:
+# content += line + "\n"
+# if content == "":
+# if not config.prompt_file:
+# content = "hi"
+# else:
+# content = open(config.prompt_file, "r").read()
+# print("User: ", content)
+# elif os.path.isfile(content):
+# content = open(content, "r").read()
+# print("User: ", content)
+# messages = his_content + [{"role": "user", "content": content}]
+
+# async def async_inference(messages):
+# generated = ""
+# async for token in interface.inference(messages, "local_chat"):
+# generated += token
+# return generated
+
+# generated = asyncio.run(async_inference(messages))
+# his_content += [
+# {"role": "user", "content": content},
+# {"role": "assistant", "content": generated},
+# ]
+
+
+# if __name__ == "__main__":
+# local_chat()
+
+
"""
-Description :
+Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
-import asyncio
import os
import platform
import sys
+
project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
-from ktransformers.server.args import ArgumentParser
-
-
+import torch
+import logging
+from transformers import (
+ AutoTokenizer,
+ AutoConfig,
+ AutoModelForCausalLM,
+ GenerationConfig,
+ TextStreamer,
+)
+import json
+import fire
+from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
-from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
+from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config
custom_models = {
@@ -29,7 +145,9 @@ custom_models = {
"MixtralForCausalLM": MixtralForCausalLM,
}
-ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
+ktransformer_rules_dir = (
+ os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
+)
default_optimize_rules = {
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
@@ -39,28 +157,85 @@ default_optimize_rules = {
}
-def local_chat():
- config = Config()
- arg_parser = ArgumentParser(config)
- # 初始化消息
- arg_parser.parse_args()
- if config.backend_type == "transformers":
- from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
- elif config.backend_type == "exllamav2":
- from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
- elif config.backend_type == "ktransformers":
- from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
+def local_chat(
+ model_path: str | None = None,
+ optimize_rule_path: str | None = None,
+ gguf_path: str | None = None,
+ max_new_tokens: int = 1000,
+ cpu_infer: int = Config().cpu_infer,
+ use_cuda_graph: bool = True,
+ prompt_file: str | None = None,
+ mode: str = "normal",
+):
+
+
+ torch.set_grad_enabled(False)
+
+ Config().cpu_infer = cpu_infer
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+ if mode == 'long_context':
+ assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
+ torch.set_default_dtype(torch.float16)
else:
- raise NotImplementedError(f"{config.backend_type} not implemented")
- interface = BackendInterface(config)
+ torch.set_default_dtype(config.torch_dtype)
+
+ with torch.device("meta"):
+ if config.architectures[0] in custom_models:
+ print("using custom modeling_xxx.py.")
+ if (
+ "Qwen2Moe" in config.architectures[0]
+ ): # Qwen2Moe must use flash_attention_2 to avoid overflow.
+ config._attn_implementation = "flash_attention_2"
+ if "Llama" in config.architectures[0]:
+ config._attn_implementation = "eager"
+ if "Mixtral" in config.architectures[0]:
+ config._attn_implementation = "flash_attention_2"
+
+ model = custom_models[config.architectures[0]](config)
+ else:
+ model = AutoModelForCausalLM.from_config(
+ config, trust_remote_code=True, attn_implementation="flash_attention_2"
+ )
+
+ if optimize_rule_path is None:
+ if config.architectures[0] in default_optimize_rules:
+ print("using default_optimize_rule for", config.architectures[0])
+ optimize_rule_path = default_optimize_rules[config.architectures[0]]
+ else:
+ optimize_rule_path = input(
+ "please input the path of your rule file(yaml file containing optimize rules):"
+ )
+
+ if gguf_path is None:
+ gguf_path = input(
+ "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
+ )
+ optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
+
+ try:
+ model.generation_config = GenerationConfig.from_pretrained(model_path)
+ except Exception:
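+ # Fall back to a default generation config when the checkpoint does not ship one.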
+ gen_config = GenerationConfig(
+ max_length=128,
+ temperature=0.7,
+ top_p=0.9,
+ do_sample=True
+ )
+ model.generation_config = gen_config
+ if model.generation_config.pad_token_id is None:
+ model.generation_config.pad_token_id = model.generation_config.eos_token_id
+ model.eval()
+ logging.basicConfig(level=logging.INFO)
system = platform.system()
if system == "Windows":
os.system("cls")
else:
os.system("clear")
- # add a history chat content
- his_content = []
+
while True:
content = input("Chat: ")
if content.startswith('"""'): # prefix """
@@ -76,29 +251,28 @@ def local_chat():
break
else:
content += line + "\n"
+
if content == "":
- if not config.prompt_file:
- content = "hi"
+ if prompt_file is not None:
+ content = open(prompt_file, "r").read()
else:
- content = open(config.prompt_file, "r").read()
- print("User: ", content)
+ content = "Please write a piece of quicksort code in C++."
elif os.path.isfile(content):
content = open(content, "r").read()
- print("User: ", content)
- messages = his_content + [{"role": "user", "content": content}]
-
- async def async_inference(messages):
- generated = ""
- async for token in interface.inference(messages, "local_chat"):
- generated += token
- return generated
-
- generated = asyncio.run(async_inference(messages))
- his_content += [
- {"role": "user", "content": content},
- {"role": "assistant", "content": generated},
- ]
+ messages = [{"role": "user", "content": content}]
+ input_tensor = tokenizer.apply_chat_template(
+ messages, add_generation_prompt=True, return_tensors="pt"
+ )
+ if mode == 'long_context':
+ assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
+ "please change max_seq_len in ~/.ktransformers/config.yaml"
+ torch.set_default_dtype(
+ torch.bfloat16
+ ) # TODO: Remove this, replace dtype using config
+ generated = prefill_and_generate(
+ model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
+ )
if __name__ == "__main__":
- local_chat()
+ fire.Fire(local_chat)
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 2a09b48..d24db14 100644
--- a/setup.py
+++ b/setup.py
@@ -278,13 +278,15 @@ class CMakeBuild(BuildExtension):
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
if hasattr(self, "parallel") and self.parallel:
build_args += [f"-j{self.parallel}"]
-
+ print("CMake args:", cmake_args)
build_temp = Path(ext.sourcedir) / "build"
if not build_temp.exists():
build_temp.mkdir(parents=True)
- subprocess.run(
- ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
+ result = subprocess.run(
+ ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True
)
+ print("Standard output:", result.stdout)
+ print("Standard error:", result.stderr)
subprocess.run(
["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
)