v0.2 ongoing

This commit is contained in:
liam 2025-02-09 22:39:01 +08:00
parent bf1d413be0
commit 098602b08f
11 changed files with 450 additions and 70 deletions

.gitignore vendored
View file

@@ -18,4 +18,7 @@ compile_commands.json
ktransformers/server/local_store/
ktransformers/server_test1.db
*.patch
img/
tmp1.txt
test_65_300_1536.txt
test.txt

View file

@@ -17,5 +17,5 @@ dev_install:
pip install -r requirements-local_chat.txt
echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation
echo "Installation completed successfully"

View file

@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
<h2 id="Updates">🔥 Updates</h2>
* **Feb 10, 2025**: Support DeepSeek-R1 and V3 on a single GPU (24GB VRAM) or multi-GPU setup with 382GB DRAM, up to XXX speedup. The detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md)
* **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
* **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU.
@@ -31,6 +32,43 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
* **Aug 9, 2024**: Support windows native.
<h2 id="show-cases">🔥 Show Cases</h2>
<div>
<h3>GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM</h3>
</div>
https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
</p>
- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 12GB VRAM and 382GB DRAM.
  - Prefill Speed (tokens/s):
    - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → xxx (optimized AMX-based MoE kernel, V0.3 only) → XXX (selectively using 6 experts, V0.3 only)
    - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **XXX× speedup**.
  - Decode Speed (tokens/s):
    - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only)
    - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.
  - Upcoming Open Source Release:
    - AMX optimizations and selective expert activation will be open-sourced in v0.3.
    - Currently available only in a preview binary distribution, which can be found here.
- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).
<p align="center">
<picture>
<img alt="DeepSeek-Coder-V2 Score" src="https://github.com/user-attachments/assets/d052924e-8631-44de-aad2-97c54b965693" width=100%>
</picture>
</p>
- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin).
- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.
<p align="center">
https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c
</p>
<h3>1M Context Local Inference on a Desktop with Only 24GB VRAM</h3>
<p align="center">
@@ -54,30 +92,7 @@ https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12
* **Flexible Sparse Attention Framework**: Offers a flexible block sparse attention framework for CPU offloaded decoding. Compatible with SnapKV, Quest, and InfLLm. Further information is available [here](./doc/en/long_context_introduction.md).
<strong>More advanced features are coming soon, so stay tuned!</strong>

View file

@@ -0,0 +1,64 @@
## Prerequisites
We run our best-performance tests on:<br>
CPU: Intel(R) Xeon(R) Gold 6454S, 1TB DRAM (2 NUMA nodes)<br>
GPU: 4090D, 24GB VRAM<br>
## Benchmark Results
### V0.2
#### Settings
- Model: DeepseekV3-q4km (int4)<br>
- CPU: Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 sockets, 2 NUMA nodes
- GPU: 4090D, 24GB VRAM
- All numbers are measured after sufficient warm-up.
#### Memory consumption
- Single socket: 382GB DRAM, 12GB VRAM
- Dual socket: 1TB DRAM, 12GB VRAM
#### Benchmark results
The "6 experts" case is part of the v0.3 preview.
| Prompt<br>(500 tokens) | Dual-socket KTransformers (6 experts) | Dual-socket KTransformers (8 experts) | Single-socket KTransformers (6 experts) | Single-socket KTransformers (8 experts) | llama.cpp (8 experts) |
| --- | --- | --- | --- | --- | --- |
| Prefill tokens/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 |
| Decode tokens/s | 13.69 | 12.208 | 10.303 | 8.73 | 4.51 |

**The highest speedup reaches <u>3.03×</u> in decoding and <u>9.44×</u> in prefill.**
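
These ratios follow directly from the table above, comparing dual-socket KTransformers with 6 experts against llama.cpp with 8 experts:

$$
\text{prefill speedup} = \frac{97.32}{10.31} \approx 9.44, \qquad
\text{decode speedup} = \frac{13.69}{4.51} \approx 3.03
$$
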
## How to Run
### V0.2 Showcase
#### Single-socket version (32 cores)
Our local_chat test command is:
``` shell
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 33 --cache_lens 1536
<when you see "Chat:", press enter to load the text prompt_file>
```
\<your model path\> can be a local path or an online Hugging Face ID such as deepseek-ai/DeepSeek-V3. If the online connection fails, try the mirror (hf-mirror.com).<br>
\<your gguf path\> can also be online, but since the file is large we recommend downloading it and quantizing the model to the format you want.<br>
The command `numactl -N 1 -m 1` binds the process to a single NUMA node, to avoid data transfer between NUMA nodes.
#### Dual-socket version (64 cores)
Make sure that before you install (with install.sh or `make dev_install`), you set the environment variable `USE_NUMA=1` via `export USE_NUMA=1` (if it is already installed, reinstall it with this variable set).<br>
Our local_chat test command is:
``` shell
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
export USE_NUMA=1
make dev_install # or sh ./install.sh
python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --cache_lens 1536
<when you see "Chat:", press enter to load the text prompt_file>
```
The parameters have the same meaning. Because we use dual sockets, we set cpu_infer to 65.
## Some Explanations
1. In our experience with DeepseekV2, DeepseekV3 and DeepseekR1,
slightly decreasing the number of activated experts during inference
barely changes the output quality (within a 1% accuracy drop), while decoding and prefill
speed up by about 30%, which is encouraging. Our showcase therefore makes use of this finding,
changing the number of activated experts of DeepseekV3/R1 from 8 to 6. <br>
2. We also want to make further use of the two NUMA nodes on the Xeon Gold CPU.
To avoid the cost of data transfer between nodes, we "copy" the critical matrices onto
both nodes, which consumes more memory but accelerates the prefill and decoding process;
a minimal sketch of this idea is shown after this list. Note that this method uses a lot of memory
and is slow when loading weights, so be patient during loading and monitor the memory usage.
(We are considering making this replication optional.)<br>
3. The command argument `--cpu_infer 65` specifies how many cores to use (it is fine if it exceeds the physical number,
but more is not always better; adjust it to slightly below your actual number of cores).<br>
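
To make point 2 concrete, here is a minimal, illustrative C++ sketch of the replication idea, simplified from what v0.2 actually does rather than the exact implementation; the helper names `replicate_weights` and `bind_thread_to_node` are invented for this example, and the snippet must be linked with `-lnuma`.

```cpp
#include <numa.h>
#include <cstddef>
#include <cstring>
#include <vector>

// Copy one weight buffer onto every NUMA node.
// Each worker thread later reads from the copy that lives on its own node.
std::vector<void*> replicate_weights(const void* src, size_t bytes) {
    int nodes = numa_num_configured_nodes();
    std::vector<void*> copies(nodes);
    for (int node = 0; node < nodes; node++) {
        copies[node] = numa_alloc_onnode(bytes, node);  // memory physically on `node`
        std::memcpy(copies[node], src, bytes);          // duplicate the weights
    }
    return copies;
}

// Bind the calling worker thread (CPU and memory allocation) to one NUMA node
// and return that node id, so the thread can index the replicated buffers.
int bind_thread_to_node(int thread_id, int thread_num) {
    int node = thread_id * numa_num_configured_nodes() / thread_num;
    struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());
    numa_bitmask_setbit(mask, node);
    numa_bind(mask);
    numa_bitmask_free(mask);
    return node;  // afterwards, read weights from copies[node]
}
```

This is the trade-off described above: DRAM usage grows by roughly one extra copy of the replicated weights per NUMA node, in exchange for node-local memory reads during prefill and decoding.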

View file

@@ -230,3 +230,24 @@ elseif(UNIX)
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
# Define the USE_NUMA option
option(USE_NUMA "Enable NUMA support" OFF)
# Check if the USE_NUMA environment variable is set
if(DEFINED ENV{USE_NUMA})
set(USE_NUMA ON)
endif()
if (USE_NUMA)
message(STATUS "NUMA support is enabled")
else()
message(STATUS "NUMA support is disabled")
endif()
find_library(NUMA_LIBRARY NAMES numa)
if (NUMA_LIBRARY AND USE_NUMA)
message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support")
target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})
target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA)
else()
message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support")
endif()

View file

@@ -10,6 +10,13 @@
#include "backend.h"
#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>
thread_local int Backend::numa_node = -1;
#endif
thread_local int Backend::thread_local_id = -1;
Backend::Backend(int max_thread_num) {
@@ -74,6 +81,16 @@ void Backend::do_work_stealing_job(int task_num,
}
void Backend::process_tasks(int thread_id) {
#ifdef USE_NUMA
// On first use, bind this worker thread (CPU and memory allocation) to one NUMA node.
if(numa_node == -1){
numa_node = thread_id * numa_num_configured_nodes() / thread_num_;
struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());
numa_bitmask_setbit(mask, numa_node);
numa_bind(mask);
}
#endif
if (init_func_ != nullptr) {
init_func_(thread_id);
}

View file

@@ -38,6 +38,9 @@ class Backend {
void do_work_stealing_job(int, std::function<void(int)>,
std::function<void(int)>,
std::function<void(int)>);
#ifdef USE_NUMA
static thread_local int numa_node;
#endif
static thread_local int thread_local_id;
private:

View file

@@ -11,11 +11,41 @@
#include <iostream>
#include <cstdint>
#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>
#endif
MOE::MOE(MOEConfig config) {
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
gate_proj_numa_.resize(numa_nodes);
up_proj_numa_.resize(numa_nodes);
down_proj_numa_.resize(numa_nodes);
size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size;
for (int i = 0; i < numa_nodes; i++) {
gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i);
up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i);
down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i);
if (!gate_proj_numa_[i]) {
std::cout << "Memory allocation failed for gate_proj_numa_ on node " << i << std::endl;
}
if (!up_proj_numa_[i]) {
std::cout << "Memory allocation failed for up_proj_numa_ on node " << i << std::endl;
}
if (!down_proj_numa_[i]) {
std::cout << "Memory allocation failed for down_proj_numa_ on node " << i << std::endl;
}
memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type));
memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type));
memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type));
}
#endif
std::vector<std::pair<void**, uint64_t>> s_mem_requests;
s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});
@@ -74,6 +104,15 @@ MOE::MOE(MOEConfig config) {
MOE::~MOE() {
shared_mem_buffer.dealloc(this);
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
for (int i = 0; i < numa_nodes; i++) {
numa_free(gate_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type));
numa_free(up_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type));
numa_free(down_proj_numa_[i], config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type));
}
#endif
}
void MOE::warm_up(Backend* backend) {
@@ -125,10 +164,22 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
int expert_idx = task_id / nth;
uint64_t expert_id = expert_ids[expert_idx];
int ith = task_id % nth;
#ifdef USE_NUMA
void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#else
void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#endif
float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
#ifdef USE_NUMA
void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#else
void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#endif
float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
@@ -153,7 +204,13 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
}
for (int expert_idx = 0; expert_idx < k; expert_idx++) {
uint64_t expert_id = expert_ids[expert_idx];
#ifdef USE_NUMA
void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#else
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#endif
float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
@@ -227,11 +284,23 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];
#ifdef USE_NUMA
void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#else
void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#endif
float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
void* up_input_ptr = m_local_up_input_ptr_[expert_idx];
#ifdef USE_NUMA
void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#else
void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#endif
float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
@@ -249,7 +318,13 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx];
#ifdef USE_NUMA
void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#else
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#endif
float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
}, nullptr);

View file

@@ -61,6 +61,12 @@ class MOE {
void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]
#ifdef USE_NUMA
std::vector<void*> gate_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]
std::vector<void*> up_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]
std::vector<void*> down_proj_numa_; // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)]
#endif
float* s_input_fp32_; // [hidden_size]
uint8_t* s_gate_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
uint8_t* s_up_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]

View file

@@ -1,24 +1,140 @@
# """
# Description :
# Author : Boxin Zhang, Azure-Tang
# Version : 0.1.0
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
# """
# import asyncio
# import os
# import platform
# import sys
# project_dir = os.path.dirname(os.path.dirname(__file__))
# sys.path.insert(0, project_dir)
# from ktransformers.server.args import ArgumentParser
# from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
# from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
# from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
# from ktransformers.models.modeling_llama import LlamaForCausalLM
# from ktransformers.models.modeling_mixtral import MixtralForCausalLM
# from ktransformers.server.config.config import Config
# custom_models = {
# "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
# "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
# "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
# "LlamaForCausalLM": LlamaForCausalLM,
# "MixtralForCausalLM": MixtralForCausalLM,
# }
# ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
# default_optimize_rules = {
# "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
# "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
# "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
# "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
# "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
# }
# def local_chat():
# config = Config()
# arg_parser = ArgumentParser(config)
# # Initialize messages
# arg_parser.parse_args()
# if config.backend_type == "transformers":
# from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
# elif config.backend_type == "exllamav2":
# from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
# elif config.backend_type == "ktransformers":
# from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
# else:
# raise NotImplementedError(f"{config.backend_type} not implemented")
# interface = BackendInterface(config)
# system = platform.system()
# if system == "Windows":
# os.system("cls")
# else:
# os.system("clear")
# # add a history chat content
# his_content = []
# while True:
# content = input("Chat: ")
# if content.startswith('"""'): # prefix """
# # multi lines input
# content = content[3:] + "\n"
# while True:
# line = input("")
# if line.endswith('"""'):
# # end multi lines input
# line = line[:-3] # suffix """
# if line:
# content += line + "\n"
# break
# else:
# content += line + "\n"
# if content == "":
# if not config.prompt_file:
# content = "hi"
# else:
# content = open(config.prompt_file, "r").read()
# print("User: ", content)
# elif os.path.isfile(content):
# content = open(content, "r").read()
# print("User: ", content)
# messages = his_content + [{"role": "user", "content": content}]
# async def async_inference(messages):
# generated = ""
# async for token in interface.inference(messages, "local_chat"):
# generated += token
# return generated
# generated = asyncio.run(async_inference(messages))
# his_content += [
# {"role": "user", "content": content},
# {"role": "assistant", "content": generated},
# ]
# if __name__ == "__main__":
# local_chat()
""" """
Description : Description :
Author : Boxin Zhang, Azure-Tang Author : Boxin Zhang, Azure-Tang
Version : 0.1.0 Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
""" """
import asyncio
import os import os
import platform import platform
import sys import sys
project_dir = os.path.dirname(os.path.dirname(__file__)) project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir) sys.path.insert(0, project_dir)
from ktransformers.server.args import ArgumentParser import torch
import logging
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
custom_models = { custom_models = {
@@ -29,7 +145,9 @@ custom_models = {
    "MixtralForCausalLM": MixtralForCausalLM,
}

ktransformer_rules_dir = (
    os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = {
    "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
@@ -39,28 +157,85 @@ default_optimize_rules = {
}
def local_chat(
    model_path: str | None = None,
    optimize_rule_path: str = None,
    gguf_path: str | None = None,
    max_new_tokens: int = 1000,
    cpu_infer: int = Config().cpu_infer,
    use_cuda_graph: bool = True,
    prompt_file: str | None = None,
    mode: str = "normal",
):
    torch.set_grad_enabled(False)
    Config().cpu_infer = cpu_infer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if mode == 'long_context':
        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
        torch.set_default_dtype(torch.float16)
    else:
        torch.set_default_dtype(config.torch_dtype)

    with torch.device("meta"):
        if config.architectures[0] in custom_models:
            print("using custom modeling_xxx.py.")
            if (
                "Qwen2Moe" in config.architectures[0]
            ):  # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            if "Llama" in config.architectures[0]:
                config._attn_implementation = "eager"
            if "Mixtral" in config.architectures[0]:
                config._attn_implementation = "flash_attention_2"
            model = custom_models[config.architectures[0]](config)
        else:
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation="flash_attention_2"
            )

    if optimize_rule_path is None:
        if config.architectures[0] in default_optimize_rules:
            print("using default_optimize_rule for", config.architectures[0])
            optimize_rule_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_rule_path = input(
                "please input the path of your rule file(yaml file containing optimize rules):"
            )
    if gguf_path is None:
        gguf_path = input(
            "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
        )
    optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)

    try:
        model.generation_config = GenerationConfig.from_pretrained(model_path)
    except:
        gen_config = GenerationConfig(
            max_length=128,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        model.generation_config = gen_config
    # model.generation_config = GenerationConfig.from_pretrained(model_path)
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.eval()
    logging.basicConfig(level=logging.INFO)

    system = platform.system()
    if system == "Windows":
        os.system("cls")
    else:
        os.system("clear")

    while True:
        content = input("Chat: ")
        if content.startswith('"""'):  # prefix """
@@ -76,29 +251,28 @@ def local_chat():
                    break
                else:
                    content += line + "\n"
        if content == "":
            if prompt_file != None:
                content = open(prompt_file, "r").read()
            else:
                content = "Please write a piece of quicksort code in C++."
        elif os.path.isfile(content):
            content = open(content, "r").read()
        messages = [{"role": "user", "content": content}]
        input_tensor = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        if mode == 'long_context':
            assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
                "please change max_seq_len in ~/.ktransformers/config.yaml"
        torch.set_default_dtype(
            torch.bfloat16
        )  # TODO: Remove this, replace dtype using config
        generated = prefill_and_generate(
            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
        )


if __name__ == "__main__":
    fire.Fire(local_chat)

View file

@@ -278,13 +278,15 @@ class CMakeBuild(BuildExtension):
        if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
            if hasattr(self, "parallel") and self.parallel:
                build_args += [f"-j{self.parallel}"]

        print("CMake args:", cmake_args)
        build_temp = Path(ext.sourcedir) / "build"
        if not build_temp.exists():
            build_temp.mkdir(parents=True)

        result = subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True
        )
        print("Standard output:", result.stdout)
        print("Standard error:", result.stderr)
        subprocess.run(
            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
        )