mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 12:09:48 +00:00

⚡ v0.2 ongoing

This commit is contained in:
parent bf1d413be0
commit 098602b08f

11 changed files with 450 additions and 70 deletions
.gitignore (vendored, 3 changes)

@@ -19,3 +19,6 @@ ktransformers/server/local_store/
 ktransformers/server_test1.db
 *.patch
 img/
+tmp1.txt
+test_65_300_1536.txt
+test.txt
Makefile (2 changes)

@@ -17,5 +17,5 @@ dev_install:
 pip install -r requirements-local_chat.txt

 echo "Installing ktransformers"
-KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation
+KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation
 echo "Installation completed successfully"
README.md (61 changes)

@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

 <h2 id="Updates">🔥 Updates</h2>

+* **Feb 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi GPU and 382G DRAM, up to XXX speedup. The detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md).
 * **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md).
 * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
 * **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU.

@@ -31,6 +32,43 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
 * **Aug 9, 2024**: Support windows native.

 <h2 id="show-cases">🔥 Show Cases</h2>

+<div>
+<h3>GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM</h3>
+</div>
+
+https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
+
+</p>
+
+- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 12GB VRAM and 382GB DRAM.
+  - Prefill Speed (tokens/s):
+    - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → xxx (optimized AMX-based MoE kernel, V3 only) → XXX (selectively using 6 experts, V3 only)
+    - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **XXX× speedup**.
+  - Decode Speed (tokens/s):
+    - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V3 only)
+    - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.
+  - Upcoming Open Source Release:
+    - AMX optimizations and selective expert activation will be open-sourced in v0.3.
+    - Currently available only in a preview binary distribution, which can be found here.
+
+- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).
+
+<p align="center">
+  <picture>
+    <img alt="DeepSeek-Coder-V2 Score" src="https://github.com/user-attachments/assets/d052924e-8631-44de-aad2-97c54b965693" width=100%>
+  </picture>
+</p>
+
+- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin).
+- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.
+
+<p align="center">
+
+https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c
+
+</p>
+
 <h3>1M Context Local Inference on a Desktop with Only 24GB VRAM</h3>
 <p align="center">

@@ -54,30 +92,7 @@ https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12

 * **Flexible Sparse Attention Framework**: Offers a flexible block sparse attention framework for CPU offloaded decoding. Compatible with SnapKV, Quest, and InfLLm. Further information is available [here](./doc/en/long_context_introduction.md).

-<div>
-<h3>GPT-4-level Local VSCode Copilot on a Desktop with only 24GB VRAM</h3>
-</div>
-
-https://github.com/user-attachments/assets/0b9fa2da-66f0-48eb-b4b9-f0e1f06f8927
-
-</p>
-
-- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).
-
-<p align="center">
-  <picture>
-    <img alt="DeepSeek-Coder-V2 Score" src="https://github.com/user-attachments/assets/d052924e-8631-44de-aad2-97c54b965693" width=100%>
-  </picture>
-</p>
-
-- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin).
-- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.
-
-<p align="center">
-
-https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c
-
-</p>
-
 <strong>More advanced features will be coming soon, so stay tuned!</strong>
doc/en/DeepseekR1_V3_tutorial.md (new file, 64 lines)

@@ -0,0 +1,64 @@
## Prerequisites
We ran our best performance tests on:<br>
CPU: Intel(R) Xeon(R) Gold 6454S, 1TB DRAM (2 NUMA nodes)<br>
GPU: 4090D, 24GB VRAM<br>

## Bench results
### V0.2
#### Settings
- Model: DeepseekV3-q4km (int4)
- CPU: Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 sockets, 2 NUMA nodes
- GPU: 4090D, 24GB VRAM
- We test after enough warm-up.

#### Memory consumption
- Single socket: 382GB DRAM, 12GB VRAM
- Dual socket: 1TB DRAM, 12GB VRAM

#### Benchmark results

The "6 experts" case is part of the v0.3 preview.

| Prompt<br>(500 tokens) | Dual-socket Ktrans (6 experts) | Dual-socket Ktrans (8 experts) | Single-socket Ktrans (6 experts) | Single-socket Ktrans (8 experts) | llama.cpp (8 experts) |
| --- | --- | --- | --- | --- | --- |
| Prefill tokens/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 |
| Decode tokens/s | 13.69 | 12.208 | 10.303 | 8.73 | 4.51 |

**The highest speedup reaches up to <u>x3.03</u> in decoding and <u>x9.44</u> in prefill.**
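For reference, the headline factors follow directly from the table above, dividing the dual-socket 6-expert numbers by the llama.cpp (8 experts) baseline:

$$
\text{decode speedup} = \frac{13.69}{4.51} \approx 3.03,
\qquad
\text{prefill speedup} = \frac{97.32}{10.31} \approx 9.44
$$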
## How to run
### v0.2 showcase
#### Single-socket version (32 cores)
Our local_chat test command is:
``` shell
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 33 --cache_lens 1536
# when you see "Chat:", press enter to load the text from prompt_file
```
\<your model path\> can be a local path or an online Hugging Face ID such as deepseek-ai/DeepSeek-V3. If the online connection fails, try the mirror (hf-mirror.com).<br>
\<your gguf path\> can also be online, but since the file is large we recommend downloading it and quantizing the model to the format you want.<br>
The command `numactl -N 1 -m 1` avoids data transfer between NUMA nodes.

#### Dual-socket version (64 cores)
Make sure that before you install (using install.sh or `make dev_install`), you set the env var `USE_NUMA=1`, e.g. by `export USE_NUMA=1` (if already installed, reinstall with this env var set).<br>
Our local_chat test command is:
``` shell
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
export USE_NUMA=1
make dev_install # or sh ./install.sh
python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --cache_lens 1536
# when you see "Chat:", press enter to load the text from prompt_file
```
The parameters have the same meaning. Since we use dual sockets, we set cpu_infer to 65.

## Some explanations
1. In our experiments with DeepSeekV2, DeepSeekV3, and DeepSeekR1, slightly decreasing the number of activated experts at inference time does not noticeably change the output quality (within a 1% accuracy drop), but decoding and prefill speed up by about 30%, which is encouraging. Our showcase therefore uses this finding and changes the number of activated experts of DeepSeekV3/R1 from 8 to 6.<br>
2. We also want to make further use of the two NUMA nodes on the Xeon Gold CPU. To avoid the cost of data transfer between nodes, we "copy" the critical matrices onto both nodes, which costs more memory but accelerates prefill and decoding; a minimal sketch of this replication is shown after this list. Note that this approach uses a lot of memory and is slow when loading weights, so be patient during loading and monitor the memory usage. (We are considering making it optional.)<br>
3. The argument `--cpu_infer 65` specifies how many cores to use (it is fine if it exceeds the physical count, but more is not always better; adjust it slightly below your actual number of cores).<br>
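For readers curious how the "copy onto both nodes" idea in point 2 can be realized, here is a minimal, illustrative libnuma sketch. The function and variable names are ours and not part of ktransformers; the actual implementation lives in the CPU MoE operator changes later in this commit.

```cpp
// Minimal sketch (not the actual ktransformers code): replicate a read-only
// weight buffer onto every NUMA node with libnuma. Build with -lnuma.
#include <numa.h>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

// Returns one copy of `src` per configured NUMA node.
// Each copy should later be released with numa_free(ptr, bytes).
std::vector<void*> replicate_to_all_nodes(const void* src, size_t bytes) {
    std::vector<void*> copies;
    int nodes = numa_num_configured_nodes();
    for (int node = 0; node < nodes; node++) {
        void* local = numa_alloc_onnode(bytes, node);  // pages placed on this node
        if (local == nullptr) {
            std::fprintf(stderr, "allocation failed on node %d\n", node);
            continue;
        }
        std::memcpy(local, src, bytes);                // duplicate the weights
        copies.push_back(local);
    }
    return copies;
}
```

A worker thread bound to node `n` then reads only from `copies[n]`, so no expert weight has to cross the socket interconnect during prefill or decode.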
@@ -230,3 +230,24 @@ elseif(UNIX)
     endif()
     target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
 endif()
+
+# Define the USE_NUMA option
+option(USE_NUMA "Enable NUMA support" OFF)
+# Check if the USE_NUMA environment variable is set
+if(DEFINED ENV{USE_NUMA})
+    set(USE_NUMA ON)
+endif()
+if (USE_NUMA)
+    message(STATUS "NUMA support is enabled")
+else()
+    message(STATUS "NUMA support is disabled")
+endif()
+
+find_library(NUMA_LIBRARY NAMES numa)
+if (NUMA_LIBRARY AND USE_NUMA)
+    message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support")
+    target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})
+    target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA)
+else()
+    message(STATUS "NUMA library not found or USE_NUMA not set - disabling NUMA support")
+endif()
@@ -10,6 +10,13 @@

 #include "backend.h"

+#ifdef USE_NUMA
+#include <numa.h>
+#include <numaif.h>
+
+thread_local int Backend::numa_node = -1;
+#endif
+
 thread_local int Backend::thread_local_id = -1;

 Backend::Backend(int max_thread_num) {

@@ -74,6 +81,16 @@ void Backend::do_work_stealing_job(int task_num,
 }

 void Backend::process_tasks(int thread_id) {
+
+#ifdef USE_NUMA
+    if(numa_node == -1){
+        numa_node = thread_id * numa_num_configured_nodes() / thread_num_;
+        struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());
+        numa_bitmask_setbit(mask, numa_node);
+        numa_bind(mask);
+    }
+#endif
+
     if (init_func_ != nullptr) {
         init_func_(thread_id);
     }
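A note on the binding logic added to `Backend::process_tasks` above: the expression `thread_id * numa_num_configured_nodes() / thread_num_` splits the worker threads into equal contiguous blocks, one block per NUMA node. Below is a standalone, illustrative sketch of that mapping (the node and thread counts are hard-coded here instead of queried via libnuma; the numbers are examples only).

```cpp
// Illustrative only: show how the integer formula partitions threads across nodes.
#include <cstdio>

// Same arithmetic as the USE_NUMA branch of Backend::process_tasks.
int node_of_thread(int thread_id, int num_nodes, int num_threads) {
    return thread_id * num_nodes / num_threads;
}

int main() {
    const int num_nodes = 2;     // e.g. a dual-socket machine
    const int num_threads = 64;  // e.g. 64 worker threads
    for (int t = 0; t < num_threads; t++) {
        std::printf("thread %2d -> node %d\n", t, node_of_thread(t, num_nodes, num_threads));
    }
    // With these numbers, threads 0..31 land on node 0 and threads 32..63 on node 1,
    // so each thread only touches the weight copy that is local to its socket.
    return 0;
}
```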
@@ -38,6 +38,9 @@ class Backend {
     void do_work_stealing_job(int, std::function<void(int)>,
                               std::function<void(int)>,
                               std::function<void(int)>);
+#ifdef USE_NUMA
+    static thread_local int numa_node;
+#endif
     static thread_local int thread_local_id;

   private:
@@ -11,12 +11,42 @@
 #include <iostream>
 #include <cstdint>

+#ifdef USE_NUMA
+#include <numa.h>
+#include <numaif.h>
+#endif
+
 MOE::MOE(MOEConfig config) {
     config_ = config;
     gate_proj_ = config_.gate_proj;
     up_proj_ = config_.up_proj;
     down_proj_ = config_.down_proj;

+#ifdef USE_NUMA
+    int numa_nodes = numa_num_configured_nodes();
+    gate_proj_numa_.resize(numa_nodes);
+    up_proj_numa_.resize(numa_nodes);
+    down_proj_numa_.resize(numa_nodes);
+    size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size;
+    for (int i = 0; i < numa_nodes; i++) {
+        gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_ * ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i);
+        up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_ * ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i);
+        down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_ * ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i);
+        if (!gate_proj_numa_[i]) {
+            std::cout << "Memory allocation failed for gate_proj_numa_ on node " << i << std::endl;
+        }
+        if (!up_proj_numa_[i]) {
+            std::cout << "Memory allocation failed for up_proj_numa_ on node " << i << std::endl;
+        }
+        if (!down_proj_numa_[i]) {
+            std::cout << "Memory allocation failed for down_proj_numa_ on node " << i << std::endl;
+        }
+        memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_ * ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type));
+        memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_ * ggml_type_size(config.up_type) / ggml_blck_size(config.up_type));
+        memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_ * ggml_type_size(config.down_type) / ggml_blck_size(config.down_type));
+    }
+#endif
+
     std::vector<std::pair<void**, uint64_t>> s_mem_requests;
     s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});
     s_mem_requests.push_back({(void**)&s_gate_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});

@@ -74,6 +104,15 @@ MOE::MOE(MOEConfig config) {

 MOE::~MOE() {
     shared_mem_buffer.dealloc(this);
+
+#ifdef USE_NUMA
+    int numa_nodes = numa_num_configured_nodes();
+    for (int i = 0; i < numa_nodes; i++) {
+        numa_free(gate_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type));
+        numa_free(up_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type));
+        numa_free(down_proj_numa_[i], config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type));
+    }
+#endif
 }

 void MOE::warm_up(Backend* backend) {

@@ -125,10 +164,22 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
         int expert_idx = task_id / nth;
         uint64_t expert_id = expert_ids[expert_idx];
         int ith = task_id % nth;
+
+#ifdef USE_NUMA
+        void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+#else
         void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+#endif
+
         float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;
         llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+
+#ifdef USE_NUMA
+        void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+#else
         void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+#endif
+
         float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
         llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {

@@ -153,7 +204,13 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
     }
     for (int expert_idx = 0; expert_idx < k; expert_idx++) {
         uint64_t expert_id = expert_ids[expert_idx];
+
+#ifdef USE_NUMA
+        void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+#else
         void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+#endif
+
         float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;
         llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {

@@ -227,11 +284,23 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
         uint64_t expert_idx = task_id / nth;
         int ith = task_id % nth;
         void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];
+
+#ifdef USE_NUMA
+        void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+#else
         void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+#endif
+
         float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;
         llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         void* up_input_ptr = m_local_up_input_ptr_[expert_idx];
+
+#ifdef USE_NUMA
+        void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+#else
         void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+#endif
+
         float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
         llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         for (int i = 0; i < m_local_num_[expert_idx]; i++) {

@@ -249,7 +318,13 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
         uint64_t expert_idx = task_id / nth;
         int ith = task_id % nth;
         void* down_input_ptr = m_local_down_input_ptr_[expert_idx];
+
+#ifdef USE_NUMA
+        void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+#else
         void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+#endif
+
         float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;
         llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
     }, nullptr);
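The pointer arithmetic that recurs in the `USE_NUMA`/`#else` branches above addresses each worker's tile in units of quantized ggml blocks: a row of `hidden_size` (or `intermediate_size`) elements occupies `row_len * ggml_type_size(type) / ggml_blck_size(type)` bytes. Below is a self-contained, illustrative sketch of the same bookkeeping; the quantization parameters are hard-coded instead of queried from ggml, and all concrete numbers are examples only.

```cpp
// Illustrative sketch of the tile-offset arithmetic used in the MoE kernels above.
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct QuantType {
    size_t type_size;  // bytes per quantized block (e.g. 144 for a Q4_K-style format)
    size_t blck_size;  // elements per quantized block (e.g. 256 for a Q4_K-style format)
};

// Byte offset of the tile handled by worker `ith` inside expert `expert_id`,
// where each expert stores `rows_per_expert` rows of `row_len` elements and
// each worker covers `stride` consecutive rows.
size_t tile_offset(uint64_t expert_id, int ith, size_t stride,
                   size_t rows_per_expert, size_t row_len, QuantType t) {
    size_t first_row = expert_id * rows_per_expert + ith * stride;  // global row index
    return first_row * row_len * t.type_size / t.blck_size;         // rows -> bytes
}

int main() {
    QuantType q4k{144, 256};  // example block parameters
    // Example numbers only: expert 3, worker 2, stride of 16 rows,
    // 2048 rows per expert, 7168 elements per row.
    std::printf("tile offset = %zu bytes\n",
                tile_offset(3, 2, 16, 2048, 7168, q4k));
    return 0;
}
```

The per-node copies (`gate_proj_numa_`, `up_proj_numa_`, `down_proj_numa_`) use exactly the same offsets as the original buffers; only the base pointer changes to the copy local to `Backend::numa_node`.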
@@ -61,6 +61,12 @@ class MOE {
     void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
     void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]

+#ifdef USE_NUMA
+    std::vector<void*> gate_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]
+    std::vector<void*> up_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]
+    std::vector<void*> down_proj_numa_; // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)]
+#endif
+
     float* s_input_fp32_; // [hidden_size]
     uint8_t* s_gate_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
     uint8_t* s_up_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
@@ -1,3 +1,109 @@
+# """
+# Description :
+# Author : Boxin Zhang, Azure-Tang
+# Version : 0.1.0
+# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+# """
+
+# import asyncio
+# import os
+# import platform
+# import sys
+# project_dir = os.path.dirname(os.path.dirname(__file__))
+# sys.path.insert(0, project_dir)
+# from ktransformers.server.args import ArgumentParser
+
+
+# from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
+# from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
+# from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
+# from ktransformers.models.modeling_llama import LlamaForCausalLM
+# from ktransformers.models.modeling_mixtral import MixtralForCausalLM
+# from ktransformers.server.config.config import Config
+
+# custom_models = {
+#     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
+#     "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
+#     "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
+#     "LlamaForCausalLM": LlamaForCausalLM,
+#     "MixtralForCausalLM": MixtralForCausalLM,
+# }
+
+# ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
+# default_optimize_rules = {
+#     "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
+#     "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
+#     "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
+#     "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
+#     "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
+# }
+
+
+# def local_chat():
+#     config = Config()
+#     arg_parser = ArgumentParser(config)
+#     # initialize messages
+#     arg_parser.parse_args()
+#     if config.backend_type == "transformers":
+#         from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
+#     elif config.backend_type == "exllamav2":
+#         from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
+#     elif config.backend_type == "ktransformers":
+#         from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
+#     else:
+#         raise NotImplementedError(f"{config.backend_type} not implemented")
+#     interface = BackendInterface(config)
+
+#     system = platform.system()
+#     if system == "Windows":
+#         os.system("cls")
+#     else:
+#         os.system("clear")
+#     # add a history chat content
+#     his_content = []
+#     while True:
+#         content = input("Chat: ")
+#         if content.startswith('"""'): # prefix """
+#             # multi lines input
+#             content = content[3:] + "\n"
+#             while True:
+#                 line = input("")
+#                 if line.endswith('"""'):
+#                     # end multi lines input
+#                     line = line[:-3] # suffix """
+#                     if line:
+#                         content += line + "\n"
+#                     break
+#                 else:
+#                     content += line + "\n"
+#         if content == "":
+#             if not config.prompt_file:
+#                 content = "hi"
+#             else:
+#                 content = open(config.prompt_file, "r").read()
+#             print("User: ", content)
+#         elif os.path.isfile(content):
+#             content = open(content, "r").read()
+#             print("User: ", content)
+#         messages = his_content + [{"role": "user", "content": content}]
+
+#         async def async_inference(messages):
+#             generated = ""
+#             async for token in interface.inference(messages, "local_chat"):
+#                 generated += token
+#             return generated
+
+#         generated = asyncio.run(async_inference(messages))
+#         his_content += [
+#             {"role": "user", "content": content},
+#             {"role": "assistant", "content": generated},
+#         ]
+
+
+# if __name__ == "__main__":
+#     local_chat()
+
+
 """
 Description :
 Author : Boxin Zhang, Azure-Tang

@@ -5,20 +111,30 @@ Version : 0.1.0
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 """

-import asyncio
 import os
 import platform
 import sys

 project_dir = os.path.dirname(os.path.dirname(__file__))
 sys.path.insert(0, project_dir)
-from ktransformers.server.args import ArgumentParser
+import torch
+import logging
+from transformers import (
+    AutoTokenizer,
+    AutoConfig,
+    AutoModelForCausalLM,
+    GenerationConfig,
+    TextStreamer,
+)
+import json
+import fire
+from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
-from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
+from ktransformers.util.utils import prefill_and_generate
 from ktransformers.server.config.config import Config

 custom_models = {

@@ -29,7 +145,9 @@ custom_models = {
     "MixtralForCausalLM": MixtralForCausalLM,
 }

-ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
+ktransformer_rules_dir = (
+    os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
+)
 default_optimize_rules = {
     "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
     "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",

@@ -39,28 +157,85 @@ default_optimize_rules = {
 }

-def local_chat():
-    config = Config()
-    arg_parser = ArgumentParser(config)
-    # initialize messages
-    arg_parser.parse_args()
-    if config.backend_type == "transformers":
-        from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
-    elif config.backend_type == "exllamav2":
-        from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
-    elif config.backend_type == "ktransformers":
-        from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
+def local_chat(
+    model_path: str | None = None,
+    optimize_rule_path: str = None,
+    gguf_path: str | None = None,
+    max_new_tokens: int = 1000,
+    cpu_infer: int = Config().cpu_infer,
+    use_cuda_graph: bool = True,
+    prompt_file: str | None = None,
+    mode: str = "normal",
+):
+
+    torch.set_grad_enabled(False)
+
+    Config().cpu_infer = cpu_infer
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if mode == 'long_context':
+        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
+        torch.set_default_dtype(torch.float16)
     else:
-        raise NotImplementedError(f"{config.backend_type} not implemented")
-    interface = BackendInterface(config)
+        torch.set_default_dtype(config.torch_dtype)
+
+    with torch.device("meta"):
+        if config.architectures[0] in custom_models:
+            print("using custom modeling_xxx.py.")
+            if (
+                "Qwen2Moe" in config.architectures[0]
+            ):  # Qwen2Moe must use flash_attention_2 to avoid overflow.
+                config._attn_implementation = "flash_attention_2"
+            if "Llama" in config.architectures[0]:
+                config._attn_implementation = "eager"
+            if "Mixtral" in config.architectures[0]:
+                config._attn_implementation = "flash_attention_2"
+
+            model = custom_models[config.architectures[0]](config)
+        else:
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=True, attn_implementation="flash_attention_2"
+            )
+
+    if optimize_rule_path is None:
+        if config.architectures[0] in default_optimize_rules:
+            print("using default_optimize_rule for", config.architectures[0])
+            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+        else:
+            optimize_rule_path = input(
+                "please input the path of your rule file(yaml file containing optimize rules):"
+            )
+
+    if gguf_path is None:
+        gguf_path = input(
+            "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
+        )
+    optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
+
+    try:
+        model.generation_config = GenerationConfig.from_pretrained(model_path)
+    except:
+        gen_config = GenerationConfig(
+            max_length=128,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True
+        )
+        model.generation_config = gen_config
+    # model.generation_config = GenerationConfig.from_pretrained(model_path)
+    if model.generation_config.pad_token_id is None:
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
+    model.eval()
+    logging.basicConfig(level=logging.INFO)

     system = platform.system()
     if system == "Windows":
         os.system("cls")
     else:
         os.system("clear")
-    # add a history chat content
-    his_content = []
     while True:
         content = input("Chat: ")
         if content.startswith('"""'): # prefix """

@@ -76,29 +251,28 @@ def local_chat():
                     break
                 else:
                     content += line + "\n"

         if content == "":
-            if not config.prompt_file:
-                content = "hi"
+            if prompt_file != None:
+                content = open(prompt_file, "r").read()
             else:
-                content = open(config.prompt_file, "r").read()
-            print("User: ", content)
+                content = "Please write a piece of quicksort code in C++."
         elif os.path.isfile(content):
             content = open(content, "r").read()
-            print("User: ", content)
-        messages = his_content + [{"role": "user", "content": content}]
-
-        async def async_inference(messages):
-            generated = ""
-            async for token in interface.inference(messages, "local_chat"):
-                generated += token
-            return generated
-
-        generated = asyncio.run(async_inference(messages))
-        his_content += [
-            {"role": "user", "content": content},
-            {"role": "assistant", "content": generated},
-        ]
+        messages = [{"role": "user", "content": content}]
+        input_tensor = tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, return_tensors="pt"
+        )
+        if mode == 'long_context':
+            assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
+                "please change max_seq_len in ~/.ktransformers/config.yaml"
+            torch.set_default_dtype(
+                torch.bfloat16
+            )  # TODO: Remove this, replace dtype using config
+        generated = prefill_and_generate(
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
+        )


 if __name__ == "__main__":
-    local_chat()
+    fire.Fire(local_chat)
setup.py (8 changes)

@@ -278,13 +278,15 @@ class CMakeBuild(BuildExtension):
         if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
             if hasattr(self, "parallel") and self.parallel:
                 build_args += [f"-j{self.parallel}"]
+        print("CMake args:", cmake_args)
         build_temp = Path(ext.sourcedir) / "build"
         if not build_temp.exists():
             build_temp.mkdir(parents=True)
-        subprocess.run(
-            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
+        result = subprocess.run(
+            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True
         )
+        print("Standard output:", result.stdout)
+        print("Standard error:", result.stderr)
         subprocess.run(
             ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
         )