From 098602b08fdc92badf8331eab3deb6f56c5166f1 Mon Sep 17 00:00:00 2001 From: liam Date: Sun, 9 Feb 2025 22:39:01 +0800 Subject: [PATCH] :zap: v0.2 ongoing --- .gitignore | 5 +- Makefile | 2 +- README.md | 61 +++-- doc/en/DeepseekR1_V3_tutorial.md | 64 +++++ .../ktransformers_ext/CMakeLists.txt | 21 ++ .../ktransformers_ext/cpu_backend/backend.cpp | 17 ++ .../ktransformers_ext/cpu_backend/backend.h | 3 + .../operators/llamafile/moe.cpp | 75 +++++ .../operators/llamafile/moe.h | 6 + ktransformers/local_chat.py | 258 +++++++++++++++--- setup.py | 8 +- 11 files changed, 450 insertions(+), 70 deletions(-) create mode 100644 doc/en/DeepseekR1_V3_tutorial.md diff --git a/.gitignore b/.gitignore index c33a95d..d45e956 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,7 @@ compile_commands.json ktransformers/server/local_store/ ktransformers/server_test1.db *.patch -img/ \ No newline at end of file +img/ +tmp1.txt +test_65_300_1536.txt +test.txt diff --git a/Makefile b/Makefile index dbf771d..f8633a9 100644 --- a/Makefile +++ b/Makefile @@ -17,5 +17,5 @@ dev_install: pip install -r requirements-local_chat.txt echo "Installing ktransformers" - KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation + KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation echo "Installation completed successfully" \ No newline at end of file diff --git a/README.md b/README.md index eb23bf8..8d92cb7 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

+* **Feb 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi GPU and 382GB DRAM, up to XXX speedup. The detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md).
* **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
* **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU.
@@ -31,6 +32,43 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
* **Aug 9, 2024**: Support windows native.

🔥 Show Cases

+ +
+

GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM

+
+ +https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 + +

+
+- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 12GB VRAM and 382GB DRAM.
+  - Prefill Speed (tokens/s):
+    - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → xxx (optimized AMX-based MoE kernel, v3 only) → XXX (selectively using 6 experts, v3 only)
+    - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **XXX× speedup**.
+  - Decode Speed (tokens/s):
+    - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, v3 only)
+    - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.
+  - Upcoming Open Source Release:
+    - AMX optimizations and selective expert activation will be open-sourced in v0.3.
+    - Currently available only in a preview binary distribution, which can be found here.
+
+- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).
+
+

+ + DeepSeek-Coder-V2 Score + +

+ +- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin). +- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends. + +

+ +https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c + +

+

1M Context Local Inference on a Desktop with Only 24GB VRAM

@@ -54,30 +92,7 @@ https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12 * **Flexible Sparse Attention Framework**: Offers a flexible block sparse attention framework for CPU offloaded decoding. Compatible with SnapKV, Quest, and InfLLm. Further information is available [here](./doc/en/long_context_introduction.md). -

-

GPT-4-level Local VSCode Copilot on a Desktop with only 24GB VRAM

-
-https://github.com/user-attachments/assets/0b9fa2da-66f0-48eb-b4b9-f0e1f06f8927 - -

- -- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench). - -

- - DeepSeek-Coder-V2 Score - -

- -- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin). -- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends. - -

- -https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c - -

More advanced features will coming soon, so stay tuned! diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md new file mode 100644 index 0000000..1b1e6c7 --- /dev/null +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -0,0 +1,64 @@ +## prerequisites +We run our best performance tests on
+CPU: Intel(R) Xeon(R) Gold 6454S, 1TB DRAM (2 NUMA nodes)
+GPU: 4090D, 24GB VRAM
+## Benchmarks
+### V0.2
+#### Settings
+- Model: DeepSeek-V3 Q4_K_M (int4)
+- CPU: Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 sockets, 2 NUMA nodes
+- GPU: 4090D, 24GB VRAM
+- All numbers are measured after sufficient warm-up.
+#### Memory consumption
+  - Single socket: 382GB DRAM, 12GB VRAM
+  - Dual socket: 1TB DRAM, 12GB VRAM
+
+#### Benchmark Results
+
+The "6 experts" case is a preview of v0.3.
+
+| Prompt
(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts) | llama.cpp (8 experts) |
+| --- | --- | --- | --- | --- | --- |
+| Prefill tokens/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 |
+| Decode tokens/s | 13.69 | 12.208 | 10.303 | 8.73 | 4.51 |
+
+**The highest speedup reaches up to 9.44× in prefill and 3.03× in decoding.**
+
+## How to run
+### V0.2 showcase
+#### Single-socket version (32 cores)
+Our local_chat test command is:
+``` shell
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 33 --cache_lens 1536
+
+```
+\<your model path\> can be a local path or an online Hugging Face id such as deepseek-ai/DeepSeek-V3. If the Hugging Face connection is a problem, try the mirror (hf-mirror.com).
+\<your gguf path\> can also be online, but since the weights are large we recommend downloading them and quantizing the model to the format you want.
+The command prefix `numactl -N 1 -m 1` aims to avoid data transfer between NUMA nodes.
+#### Dual-socket version (64 cores)
+Make sure that before you install (using install.sh or `make dev_install`) you set the env var `USE_NUMA=1` with `export USE_NUMA=1` (if ktransformers is already installed, reinstall it with this env var set).
+Our local_chat test command is:
+``` shell
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+export USE_NUMA=1
+make dev_install # or sh ./install.sh
+python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --cache_lens 1536
+
+```
+The parameters have the same meaning. Since we use dual sockets, we set cpu_infer to 65.
+## Some explanations
+1. In our experience with DeepSeekV2, DeepSeekV3 and DeepSeekR1,
+when we slightly decrease the number of activated experts during inference,
+the output quality does not change (within a 1% accuracy drop), but decoding and prefill
+speed up by about 30%, which is encouraging. Our showcase therefore makes use of this finding,
+changing the activated experts of DeepSeekV3/R1 from 8 to 6 (see the illustrative sketch after this list).
+2. We also want to make further use of the two NUMA nodes on the Xeon Gold CPU.
+To avoid the cost of data transfer between nodes, we "copy" the critical matrices onto
+both nodes, which costs more memory but accelerates prefill and decoding
+(this is what the `USE_NUMA` changes to moe.cpp in this patch implement).
+This method uses a lot of memory and is slow when loading weights, so be patient during loading
+and monitor the memory usage. (We are considering making this behavior optional.)
+3. The command argument `--cpu_infer 65` specifies how many cores to use (it is OK if it exceeds the physical number,
+but more is not always better; adjust it slightly below your actual number of cores).
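+
+To make point 1 concrete, below is a minimal, illustrative PyTorch sketch of top-k expert routing with a reduced number of active experts. This is not the actual ktransformers/DeepSeek routing code (the real router has its own scoring and normalization); it only shows the knob being discussed, i.e. keeping 6 instead of 8 routed experts per token.
+
+``` python
+import torch
+
+def route_tokens(router_logits: torch.Tensor, num_active_experts: int):
+    # Keep only the top-k experts per token and renormalize their weights.
+    weights, expert_ids = torch.topk(router_logits, num_active_experts, dim=-1)
+    weights = torch.softmax(weights, dim=-1)  # weights over the kept experts sum to 1
+    return expert_ids, weights
+
+# Toy example: 4 tokens routed over 8 experts.
+logits = torch.randn(4, 8)
+ids_8, w_8 = route_tokens(logits, 8)  # default: all 8 routed experts
+ids_6, w_6 = route_tokens(logits, 6)  # reduced activation, as in the showcase
+print(ids_6.shape, w_6.sum(dim=-1))   # torch.Size([4, 6]); per-token weights sum to 1
+```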
diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt index 1ef9823..d9ecd7a 100644 --- a/ktransformers/ktransformers_ext/CMakeLists.txt +++ b/ktransformers/ktransformers_ext/CMakeLists.txt @@ -230,3 +230,24 @@ elseif(UNIX) endif() target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") endif() + +# Define the USE_NUMA option +option(USE_NUMA "Disable NUMA support" OFF) +# Check if the USE_NUMA environment variable is set +if(DEFINED ENV{USE_NUMA}) + set(USE_NUMA ON) +endif() +if (USE_NUMA) + message(STATUS "NUMA support is enabled") +else() + message(STATUS "NUMA support is disabled") +endif() + +find_library(NUMA_LIBRARY NAMES numa) +if (NUMA_LIBRARY AND USE_NUMA) + message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support") + target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY}) + target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA) +else() + message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support") +endif() diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp index 16693f0..5980ba3 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp +++ b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp @@ -10,6 +10,13 @@ #include "backend.h" +#ifdef USE_NUMA +#include +#include + +thread_local int Backend::numa_node = -1; +#endif + thread_local int Backend::thread_local_id = -1; Backend::Backend(int max_thread_num) { @@ -74,6 +81,16 @@ void Backend::do_work_stealing_job(int task_num, } void Backend::process_tasks(int thread_id) { + + #ifdef USE_NUMA + if(numa_node == -1){ + numa_node = thread_id * numa_num_configured_nodes() / thread_num_; + struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes()); + numa_bitmask_setbit(mask, numa_node); + numa_bind(mask); + } + #endif + if (init_func_ != nullptr) { init_func_(thread_id); } diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.h b/ktransformers/ktransformers_ext/cpu_backend/backend.h index 80ff7f9..7a95f27 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/backend.h +++ b/ktransformers/ktransformers_ext/cpu_backend/backend.h @@ -38,6 +38,9 @@ class Backend { void do_work_stealing_job(int, std::function, std::function, std::function); + #ifdef USE_NUMA + static thread_local int numa_node; + #endif static thread_local int thread_local_id; private: diff --git a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp index a131b1f..35c144f 100644 --- a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp +++ b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp @@ -11,11 +11,41 @@ #include #include +#ifdef USE_NUMA +#include +#include +#endif + MOE::MOE(MOEConfig config) { config_ = config; gate_proj_ = config_.gate_proj; up_proj_ = config_.up_proj; down_proj_ = config_.down_proj; + + #ifdef USE_NUMA + int numa_nodes = numa_num_configured_nodes(); + gate_proj_numa_.resize(numa_nodes); + up_proj_numa_.resize(numa_nodes); + down_proj_numa_.resize(numa_nodes); + size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size; + for (int i = 0; i < numa_nodes; i++) { + gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i); + up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* 
ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i); + down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i); + if (!gate_proj_numa_[i]) { + std::cout << "Memory allocation failed for gate_proj_numa_ on node " << i << std::endl; + } + if (!up_proj_numa_[i]) { + std::cout << "Memory allocation failed for up_proj_numa_ on node " << i << std::endl; + } + if (!down_proj_numa_[i]) { + std::cout << "Memory allocation failed for down_proj_numa_ on node " << i << std::endl; + } + memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type)); + memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type)); + memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type)); + } + #endif std::vector> s_mem_requests; s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size}); @@ -74,6 +104,15 @@ MOE::MOE(MOEConfig config) { MOE::~MOE() { shared_mem_buffer.dealloc(this); + + #ifdef USE_NUMA + int numa_nodes = numa_num_configured_nodes(); + for (int i = 0; i < numa_nodes; i++) { + numa_free(gate_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type)); + numa_free(up_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type)); + numa_free(down_proj_numa_[i], config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type)); + } + #endif } void MOE::warm_up(Backend* backend) { @@ -125,10 +164,22 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c int expert_idx = task_id / nth; uint64_t expert_id = expert_ids[expert_idx]; int ith = task_id % nth; + + #ifdef USE_NUMA + void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + #else void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + #endif + float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); + + #ifdef USE_NUMA + void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + #else void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + #endif + float* up_output_ptr = s_up_output_[expert_idx] + ith * 
config_.stride; llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { @@ -153,7 +204,13 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c } for (int expert_idx = 0; expert_idx < k; expert_idx++) { uint64_t expert_id = expert_ids[expert_idx]; + + #ifdef USE_NUMA + void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + #else void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + #endif + float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { @@ -227,11 +284,23 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* uint64_t expert_idx = task_id / nth; int ith = task_id % nth; void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx]; + + #ifdef USE_NUMA + void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + #else void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + #endif + float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); void* up_input_ptr = m_local_up_input_ptr_[expert_idx]; + + #ifdef USE_NUMA + void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + #else void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + #endif + float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride; 
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = 0; i < m_local_num_[expert_idx]; i++) { @@ -249,7 +318,13 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* uint64_t expert_idx = task_id / nth; int ith = task_id % nth; void* down_input_ptr = m_local_down_input_ptr_[expert_idx]; + + #ifdef USE_NUMA + void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + #else void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + #endif + float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); }, nullptr); diff --git a/ktransformers/ktransformers_ext/operators/llamafile/moe.h b/ktransformers/ktransformers_ext/operators/llamafile/moe.h index a1470aa..a39e21d 100644 --- a/ktransformers/ktransformers_ext/operators/llamafile/moe.h +++ b/ktransformers/ktransformers_ext/operators/llamafile/moe.h @@ -61,6 +61,12 @@ class MOE { void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)] + #ifdef USE_NUMA + std::vector gate_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)] + std::vector up_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)] + std::vector down_proj_numa_; // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)] + #endif + float* s_input_fp32_; // [hidden_size] uint8_t* s_gate_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)] uint8_t* s_up_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)] diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 827d88f..5f17c21 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -1,24 +1,140 @@ +# """ +# Description : +# Author : Boxin Zhang, Azure-Tang +# Version : 0.1.0 +# Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+# """ + +# import asyncio +# import os +# import platform +# import sys +# project_dir = os.path.dirname(os.path.dirname(__file__)) +# sys.path.insert(0, project_dir) +# from ktransformers.server.args import ArgumentParser + + +# from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM +# from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM +# from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM +# from ktransformers.models.modeling_llama import LlamaForCausalLM +# from ktransformers.models.modeling_mixtral import MixtralForCausalLM +# from ktransformers.server.config.config import Config + +# custom_models = { +# "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, +# "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, +# "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, +# "LlamaForCausalLM": LlamaForCausalLM, +# "MixtralForCausalLM": MixtralForCausalLM, +# } + +# ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" +# default_optimize_rules = { +# "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", +# "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", +# "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", +# "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", +# "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", +# } + + +# def local_chat(): +# config = Config() +# arg_parser = ArgumentParser(config) +# # 初始化消息 +# arg_parser.parse_args() +# if config.backend_type == "transformers": +# from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface +# elif config.backend_type == "exllamav2": +# from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface +# elif config.backend_type == "ktransformers": +# from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface +# else: +# raise NotImplementedError(f"{config.backend_type} not implemented") +# interface = BackendInterface(config) + +# system = platform.system() +# if system == "Windows": +# os.system("cls") +# else: +# os.system("clear") +# # add a history chat content +# his_content = [] +# while True: +# content = input("Chat: ") +# if content.startswith('"""'): # prefix """ +# # multi lines input +# content = content[3:] + "\n" +# while True: +# line = input("") +# if line.endswith('"""'): +# # end multi lines input +# line = line[:-3] # suffix """ +# if line: +# content += line + "\n" +# break +# else: +# content += line + "\n" +# if content == "": +# if not config.prompt_file: +# content = "hi" +# else: +# content = open(config.prompt_file, "r").read() +# print("User: ", content) +# elif os.path.isfile(content): +# content = open(content, "r").read() +# print("User: ", content) +# messages = his_content + [{"role": "user", "content": content}] + +# async def async_inference(messages): +# generated = "" +# async for token in interface.inference(messages, "local_chat"): +# generated += token +# return generated + +# generated = asyncio.run(async_inference(messages)) +# his_content += [ +# {"role": "user", "content": content}, +# {"role": "assistant", "content": generated}, +# ] + + +# if __name__ == "__main__": +# local_chat() + + """ -Description : +Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ -import asyncio import os import platform import sys + project_dir = os.path.dirname(os.path.dirname(__file__)) sys.path.insert(0, project_dir) -from ktransformers.server.args import ArgumentParser - - +import torch +import logging +from transformers import ( + AutoTokenizer, + AutoConfig, + AutoModelForCausalLM, + GenerationConfig, + TextStreamer, +) +import json +import fire +from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM -from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM +from ktransformers.util.utils import prefill_and_generate from ktransformers.server.config.config import Config custom_models = { @@ -29,7 +145,9 @@ custom_models = { "MixtralForCausalLM": MixtralForCausalLM, } -ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" +ktransformer_rules_dir = ( + os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" +) default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", @@ -39,28 +157,85 @@ default_optimize_rules = { } -def local_chat(): - config = Config() - arg_parser = ArgumentParser(config) - # 初始化消息 - arg_parser.parse_args() - if config.backend_type == "transformers": - from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface - elif config.backend_type == "exllamav2": - from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface - elif config.backend_type == "ktransformers": - from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface +def local_chat( + model_path: str | None = None, + optimize_rule_path: str = None, + gguf_path: str | None = None, + max_new_tokens: int = 1000, + cpu_infer: int = Config().cpu_infer, + use_cuda_graph: bool = True, + prompt_file : str | None = None, + mode: str = "normal", +): + + + torch.set_grad_enabled(False) + + Config().cpu_infer = cpu_infer + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + if mode == 'long_context': + assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode" + torch.set_default_dtype(torch.float16) else: - raise NotImplementedError(f"{config.backend_type} not implemented") - interface = BackendInterface(config) + torch.set_default_dtype(config.torch_dtype) + + with torch.device("meta"): + if config.architectures[0] in custom_models: + print("using custom modeling_xxx.py.") + if ( + "Qwen2Moe" in config.architectures[0] + ): # Qwen2Moe must use flash_attention_2 to avoid overflow. 
+ config._attn_implementation = "flash_attention_2" + if "Llama" in config.architectures[0]: + config._attn_implementation = "eager" + if "Mixtral" in config.architectures[0]: + config._attn_implementation = "flash_attention_2" + + model = custom_models[config.architectures[0]](config) + else: + model = AutoModelForCausalLM.from_config( + config, trust_remote_code=True, attn_implementation="flash_attention_2" + ) + + if optimize_rule_path is None: + if config.architectures[0] in default_optimize_rules: + print("using default_optimize_rule for", config.architectures[0]) + optimize_rule_path = default_optimize_rules[config.architectures[0]] + else: + optimize_rule_path = input( + "please input the path of your rule file(yaml file containing optimize rules):" + ) + + if gguf_path is None: + gguf_path = input( + "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" + ) + optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) + + try: + model.generation_config = GenerationConfig.from_pretrained(model_path) + except: + gen_config = GenerationConfig( + max_length=128, + temperature=0.7, + top_p=0.9, + do_sample=True + ) + model.generation_config = gen_config + # model.generation_config = GenerationConfig.from_pretrained(model_path) + if model.generation_config.pad_token_id is None: + model.generation_config.pad_token_id = model.generation_config.eos_token_id + model.eval() + logging.basicConfig(level=logging.INFO) system = platform.system() if system == "Windows": os.system("cls") else: os.system("clear") - # add a history chat content - his_content = [] + while True: content = input("Chat: ") if content.startswith('"""'): # prefix """ @@ -76,29 +251,28 @@ def local_chat(): break else: content += line + "\n" + if content == "": - if not config.prompt_file: - content = "hi" + if prompt_file != None: + content = open(prompt_file, "r").read() else: - content = open(config.prompt_file, "r").read() - print("User: ", content) + content = "Please write a piece of quicksort code in C++." 
elif os.path.isfile(content): content = open(content, "r").read() - print("User: ", content) - messages = his_content + [{"role": "user", "content": content}] - - async def async_inference(messages): - generated = "" - async for token in interface.inference(messages, "local_chat"): - generated += token - return generated - - generated = asyncio.run(async_inference(messages)) - his_content += [ - {"role": "user", "content": content}, - {"role": "assistant", "content": generated}, - ] + messages = [{"role": "user", "content": content}] + input_tensor = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, return_tensors="pt" + ) + if mode == 'long_context': + assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ + "please change max_seq_len in ~/.ktransformers/config.yaml" + torch.set_default_dtype( + torch.bfloat16 + ) # TODO: Remove this, replace dtype using config + generated = prefill_and_generate( + model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode + ) if __name__ == "__main__": - local_chat() + fire.Fire(local_chat) \ No newline at end of file diff --git a/setup.py b/setup.py index 2a09b48..d24db14 100644 --- a/setup.py +++ b/setup.py @@ -278,13 +278,15 @@ class CMakeBuild(BuildExtension): if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: if hasattr(self, "parallel") and self.parallel: build_args += [f"-j{self.parallel}"] - + print("CMake args:", cmake_args) build_temp = Path(ext.sourcedir) / "build" if not build_temp.exists(): build_temp.mkdir(parents=True) - subprocess.run( - ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True + result = subprocess.run( + ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True ) + print("Standard output:", result.stdout) + print("Standard error:", result.stderr) subprocess.run( ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True )