Merge branch 'speculative' into dev

Li, Zonghang 2025-06-16 13:27:36 +04:00
commit fbbc30c950
6 changed files with 104 additions and 41 deletions

View file

@@ -2,6 +2,8 @@
BUILD_TARGETS = \
llama-server \
llama-cli \
llama-speculative \
llama-gguf-split \
profile-tool
# BUILD_TARGETS = \

View file

@@ -34,26 +34,26 @@ And, if your devices are more powerful, you could unlock even more possibilities
> Device D4 runs inside a Termux-simulated Linux. Device D1 reads disk data in random mode and D2~D4 read in sequential mode.
**Table 2:** Token latency for Llama models (w/o device selection).
**Table 2:** Token latency for Llama models (with device selection).
| **Model** | **llama.cpp** | **exo** | **dllama** | **prima.cpp** |
|-----------------|---------------|-----------|------------|---------------|
| Llama 3-8B | **15 ms** | 263 ms | 459 ms | 54 ms |
| Llama 3-14B | **20 ms** | - | - | 65 ms |
|----------------|---------------|-----------|------------|---------------|
| Llama 3-8B | 15 ms | 263 ms | 459 ms | **15 ms** |
| Llama 3-14B | 20 ms | - | - | **20 ms** |
| Llama 1-30B | 202 ms | - | - | **72 ms** |
| Llama 3-45B | 328 ms | - | - | **233 ms** |
| Llama 3-60B | 7965 ms | - | - | **468 ms** |
| Llama 1-65B | 8807 ms | - | - | **569 ms** |
| Llama 3-70B | 10120 ms | OOM | OOM | **674 ms** |
**Table 3:** Token latency for Qwen 2.5, QwQ, and DeepSeek R1 models (w/o device selection).
**Table 3:** Token latency for Qwen 2.5, QwQ, and DeepSeek R1 models (with device selection).
| **Model** | **llama.cpp** | **exo** | **dllama** | **prima.cpp** |
|-----------------------------------|---------------|---------------|------------|---------------|
| Qwen-2.5-7B | **14 ms** | 86 ms | - | 44 ms |
| DeepSeek-R1-Distill-Qwen-7B | **14 ms** | 68 ms | - | 52 ms |
| DeepSeek-R1-Distill-Llama-8B | **14 ms** | 77 ms | 435 ms | 59 ms |
| Qwen-2.5-14B | **23 ms** | 31710 ms | - | 65 ms |
| DeepSeek-R1-Distill-Qwen-14B | **24 ms** | 23475 ms | - | 76 ms |
| Qwen-2.5-7B | 14 ms | 86 ms | - | **14 ms** |
| DeepSeek-R1-Distill-Qwen-7B | 14 ms | 68 ms | - | **14 ms** |
| DeepSeek-R1-Distill-Llama-8B | 14 ms | 77 ms | 435 ms | **14 ms** |
| Qwen-2.5-14B | 23 ms | 31710 ms | - | **23 ms** |
| DeepSeek-R1-Distill-Qwen-14B | 24 ms | 23475 ms | - | **24 ms** |
| Qwen-2.5-32B and QwQ-32B | 224 ms | OOM | - | **89 ms** |
| DeepSeek-R1-Distill-Qwen-32B | 232 ms | OOM | - | **93 ms** |
| DeepSeek-R1-Distill-Llama-70B | 10978 ms | OOM | - | **724 ms** |
@@ -61,9 +61,9 @@ And, if your devices are more powerful, you could unlock even more possibilities
> As video recording consumes some RAM, prima.cpp proactively reduces memory usage, resulting in slightly higher latency in the video compared to the table.
> In the old version (w/o device selection), each device is assigned at least one model layer. This would lead to a 1:1:29:1 split for Llama 3-8B, which makes prima.cpp slower than llama.cpp.
> ~~In the old version (w/o device selection), each device is assigned at least one model layer. This would lead to a 1:1:29:1 split for Llama 3-8B, which makes prima.cpp slower than llama.cpp.~~
>
> **New:** In the latest version (with device selection), we will have a 0:0:32:0 split with weak devices removed, so prima.cpp effectively becomes llama.cpp when serving small models.
> In the current version (with device selection), we will have a 32:0:0:0 split with weak devices removed, so prima.cpp effectively becomes llama.cpp when serving small models.
## 🔑 Key Features
@@ -72,8 +72,10 @@ And, if your devices are more powerful, you could unlock even more possibilities
- **GPU & CPU Offloading:** If a device has a GPU, you can use both GPU and CPU for inference. For example, when VRAM is full, we can offload some model layers to RAM.
- **Piped-ring parallelism with prefetching:** Prefetch upcoming layer weights to overlap disk loading latency and use advanced piped-ring parallelism to prevent the "prefetch-release" effect. This new parallelism improves pipeline parallelism by using a ring structure and allows devices to run multiple cycles to predict a new token.
- **Heterogeneity-aware workload distribution:** A scheduler is designed to optimize workload distribution based on each device's computing power, disk speed, memory, and OS (the OS affects disk speed and the memory management strategy). It decides how many model layers a device should handle and how many of them should run on the GPU (if available). See the sketch after this list.
- **Automatic device selection:** If there are weak devices and removing them would speed up inference, prima.cpp will automatically discover and remove them.
- **Automatic device selection:** If there are weak devices and removing them would speed up inference, prima.cpp will automatically discover and remove them. Some devices may be kept as proxies so that the socket connections are not blocked.
- **Quantization:** We now support Q4K, Q6K, Q80, and IQ1 quantization (GGUF format) and are exploring a Q4K-IQ1 hybrid for a better balance between performance and speed.
- **Speculative decoding:** We now support speculative decoding, which can [further speed up inference by up to 80%](https://github.com/Lizonghang/prima.cpp/discussions/29).
- **Dynamic batching:** We now support concurrent requests from multiple users and batch decoding.
- **Supported Models:** We now support hot models like the **Llama, Qwen (and QwQ), and DeepSeek series**. More will be added in future updates.
- **Cross-Platform:** The cluster can consist of devices with different OSs, including macOS, Linux, Android, HarmonyOS, etc. Currently, Android and HarmonyOS devices require Termux, and Windows support will be added in a future update.
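To make the heterogeneity-aware scheduling idea above concrete, here is a deliberately simplified, illustrative C++ sketch of a throughput-proportional layer split. The device speed scores are made up, and the actual prima.cpp scheduler instead solves an optimization model over compute power, disk speed, memory, and OS (see the `assign_layers_to_device` hunk later in this commit).

```cpp
#include <cstdio>
#include <vector>

// Illustrative only: split n_layers across devices in proportion to a crude
// per-device "speed" score, allowing a device to receive zero layers so that
// weak devices effectively drop out. The real scheduler optimizes over
// compute, disk, memory, and OS characteristics instead.
int main() {
    const int n_layers = 32;
    const std::vector<double> speed = {8.0, 1.0, 120.0, 1.0};  // hypothetical scores for D1..D4

    double total = 0.0;
    for (double s : speed) total += s;

    std::vector<int> split(speed.size(), 0);
    int assigned = 0;
    for (size_t i = 0; i + 1 < speed.size(); ++i) {
        split[i] = static_cast<int>(n_layers * speed[i] / total);
        assigned += split[i];
    }
    split.back() = n_layers - assigned;  // the last device takes the remainder

    for (size_t i = 0; i < split.size(); ++i) {
        std::printf("device D%zu -> %d layers\n", i + 1, split[i]);
    }
    return 0;
}
```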
@@ -120,6 +122,7 @@ Before using this project, ensure you have the following dependencies installed:
**Linux (e.g., Ubuntu):**
```shell
# Use apt in Linux and pkg in Termux
sudo apt update -y && sudo apt install -y gcc-9 make cmake fio git wget libzmq3-dev
```
@@ -279,6 +282,8 @@ You can run prima.cpp in server mode, by launching `llama-server` on the rank 0
./llama-cli -m download/qwq-32b-q4_k_m.gguf --world 2 --rank 1 --master 192.168.1.2 --next 192.168.1.2 --prefetch
```
You can specify `-np 4 --cont-batching` when launching `llama-server` to enable concurrent requests.
After that, you can interact with the rank 0 device by calling the Chat Completion API:
```shell
@@ -374,6 +379,9 @@ curl -X POST http://localhost:8080/v1/cancel \
-d '{"task_id": 0}'
```
**9. How to use speculative decoding?**
Please see "[Power prima.cpp with speculative decoding: Further speeds up by up to 80%](https://github.com/Lizonghang/prima.cpp/discussions/29)".
## ❤️ Acknowledgment
This project builds upon the incredible work from the open-source community, especially [ggml, gguf](https://github.com/ggml-org/ggml), and [llama.cpp](https://github.com/ggml-org/llama.cpp). We gratefully acknowledge their contributions.

View file

@@ -1248,6 +1248,10 @@ static bool assign_layers_to_device(
return cost * k;
}
);
// apply a higher priority to the head device; 0.99 is a heuristic value chosen
// so that small models on homogeneous clusters end up with a 32:0 partitioning
// rather than 1:31.
model.lp_.col_cost_[0] *= 0.99;
// define the variable bounds
model.lp_.col_lower_ = std::vector<double>(n_world * 2, 0.0);
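The `0.99` head-device bias above is easiest to see on a toy case. The following standalone sketch (not the actual HiGHS model built in `assign_layers_to_device`) brute-forces a two-device split of 32 layers with identical per-layer costs: without the bias every split ties; with it, 32:0 becomes the unique minimum.

```cpp
#include <cstdio>

// Toy illustration of the tie-breaking heuristic: two identical devices,
// 32 layers, per-layer cost 1.0 on each. Scaling the head device's cost by
// 0.99 makes the all-on-head split strictly cheapest.
int main() {
    const int    n_layers  = 32;
    const double cost      = 1.0;
    const double head_bias = 0.99;

    int    best_w0   = -1;
    double best_cost = 1e30;
    for (int w0 = 0; w0 <= n_layers; ++w0) {
        const int    w1 = n_layers - w0;
        const double c  = head_bias * cost * w0 + cost * w1;
        if (c < best_cost) {
            best_cost = c;
            best_w0   = w0;
        }
    }
    std::printf("best split = %d:%d (cost %.2f)\n", best_w0, n_layers - best_w0, best_cost);
    return 0;  // prints: best split = 32:0 (cost 31.68)
}
```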

View file

@@ -12,7 +12,7 @@
#include <string>
#include <vector>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft {
@@ -65,23 +65,30 @@ int main(int argc, char ** argv) {
llama_context * ctx_tgt = NULL;
llama_context * ctx_dft = NULL;
// load the draft model
// make a hard copy of params to use for the draft model
gpt_params params_draft = params;
params_draft.model = params_draft.model_draft;
params_draft.n_gpu_layers = params_draft.n_gpu_layers_draft;
params_draft.n_world = 1; // do not split the draft model across devices
params_draft.rank = 0; // always load the draft model on the head device
params_draft.use_mlock = true; // always use mlock for the draft model
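// reset the per-device layer windows inherited from the target params,
// so the whole draft model is kept on the head device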
std::fill_n(params_draft.n_layer_window, params.n_world, 0);
if (params_draft.draft_cpuparams.n_threads > 0) {
params_draft.cpuparams.n_threads = params_draft.draft_cpuparams.n_threads;
}
params_draft.cpuparams_batch.n_threads = params_draft.draft_cpuparams_batch.n_threads;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_draft);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;
// load the target model
llama_init_result llama_init_tgt = llama_init_from_gpt_params(params);
model_tgt = llama_init_tgt.model;
ctx_tgt = llama_init_tgt.context;
// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
if (params.draft_cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
}
params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
@@ -161,9 +168,6 @@ int main(int argc, char ** argv) {
const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
// how many tokens to draft each time
int n_draft = params.n_draft;
@@ -180,8 +184,6 @@ int main(int argc, char ** argv) {
// target model sampling context (reuse the llama_context's sampling instance)
struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
struct llama_sampler * softmax = llama_sampler_init_softmax();
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
@@ -258,10 +260,13 @@ int main(int argc, char ** argv) {
float r = u_dist(rng);
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
//GGML_ASSERT(dist_tgt.size <= dist_dft.size);
// if (dist_tgt.size > dist_dft.size) {
// LOG_ERR("dist_tgt.size (%zu) must be less than or equal to dist_dft.size (%zu)\n", dist_tgt.size, dist_dft.size);
// GGML_ASSERT(dist_tgt.size <= dist_dft.size);
// }
// acquire the token probabilities assigned by the draft and target models
for (size_t i = 0; i < dist_tgt.size; i++) {
for (size_t i = 0; i < dist_tgt.size && i < dist_dft.size; i++) {
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
p_tgt = dist_tgt.data[i].p;
}
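For context on the loop above: it gathers `p_tgt` and `p_dft`, the probabilities the target and draft models assign to the drafted token, while the acceptance branch itself falls outside this hunk. Below is a minimal sketch of the standard stochastic speculative-sampling rule those values feed into (variable names mirror the surrounding code; this is illustrative, not prima.cpp's exact branch).

```cpp
#include <algorithm>

// Standard acceptance test in stochastic speculative sampling: keep the
// drafted token with probability min(1, p_tgt / p_dft), where r is a uniform
// draw in [0, 1) like `r = u_dist(rng)` above.
static bool accept_drafted_token(float p_tgt, float p_dft, float r) {
    if (p_dft <= 0.0f) {
        return false;  // the draft gave this token no mass: reject outright
    }
    return r < std::min(1.0f, p_tgt / p_dft);
}
```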
@@ -406,7 +411,6 @@ int main(int argc, char ** argv) {
{
LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
// TODO: simplify
{
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
@@ -418,6 +422,12 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
// notify other devices to manage the KV cache in the same way
llama_send_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_send_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_send_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_send_kv_cache_seq_keep(ctx_tgt, 0);
}
for (int s = 0; s < n_seq_dft; ++s) {
@@ -435,7 +445,6 @@ int main(int argc, char ** argv) {
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
++n_past_dft;
@@ -575,12 +584,13 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
llama_kv_cache_seq_keep (ctx_tgt, 0);
llama_send_kv_cache_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
llama_kv_cache_seq_cp (ctx_tgt, 0, s, -1, -1);
llama_send_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
}
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
++n_past_tgt;
}
@@ -612,19 +622,21 @@ int main(int argc, char ** argv) {
LOG_INF("\n");
LOG_INF("draft:\n\n");
// TODO: print sampling/grammar timings for all drafts
llama_perf_context_print(ctx_dft);
LOG_INF("\n");
LOG_INF("target:\n\n");
gpt_perf_print(ctx_tgt, smpl);
char * stop_signal = nullptr;
llama_free_sockets(ctx_tgt, &stop_signal);
gpt_sampler_free(smpl);
for (int s = 0; s < n_seq_dft; ++s) {
gpt_sampler_free(drafts[s].smpl);
}
llama_sampler_free(softmax);
llama_batch_free(batch_dft);
llama_free(ctx_tgt);

View file

@@ -772,6 +772,11 @@ extern "C" {
LLAMA_API void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id);
// Notify other nodes to keep only the specified sequence in their KV cache
LLAMA_API void llama_send_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly:
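A short usage sketch for the new broadcast call, mirroring how the speculative example in this commit pairs each local KV-cache operation with its `llama_send_*` counterpart so that every rank prunes its cache identically (the wrapper function name is ours, not part of the API):

```cpp
#include "llama.h"

// Keep only sequence s_keep on every rank, then collapse it into sequence 0.
// The llama_kv_cache_* calls act on this rank's cache; the llama_send_* calls
// added in this commit forward the same operations to the other devices.
static void keep_only_sequence(llama_context * ctx, llama_seq_id s_keep) {
    llama_kv_cache_seq_keep(ctx, s_keep);
    llama_kv_cache_seq_cp  (ctx, s_keep, 0, -1, -1);
    llama_kv_cache_seq_keep(ctx, 0);

    llama_send_kv_cache_seq_keep(ctx, s_keep);
    llama_send_kv_cache_seq_cp  (ctx, s_keep, 0, -1, -1);
    llama_send_kv_cache_seq_keep(ctx, 0);
}
```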

View file

@@ -17845,6 +17845,9 @@ struct sync_meta {
llama_pos cp_p0 = 0;
llama_pos cp_p1 = 0;
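// signal to keep only one sequence in the kv cache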
bool kv_seq_keep = false;
llama_seq_id keep_seq_id = 0;
// signal to divide the kv cache range
bool kv_seq_div = false;
llama_seq_id div_seq_id = 0;
@@ -17947,8 +17950,14 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
return 0;
}
if (cmd == "kv_seq_keep" && recv_msgs.size() == 2) {
meta->kv_seq_keep = true;
std::memcpy(&meta->keep_seq_id, recv_msgs[idx++].data(), sizeof(meta->keep_seq_id));
return 0;
}
if (cmd == "kv_seq_div" && recv_msgs.size() == 5) {
meta->kv_seq_div = true;
std::memcpy(&meta->div_seq_id, recv_msgs[idx++].data(), sizeof(meta->div_seq_id));
std::memcpy(&meta->div_p0, recv_msgs[idx++].data(), sizeof(meta->div_p0));
std::memcpy(&meta->div_p1, recv_msgs[idx++].data(), sizeof(meta->div_p1));
@@ -18338,6 +18347,14 @@ static int llama_decode_internal(
return -1;
}
if (kv_cache_op(meta.kv_seq_keep,
[&]{ llama_kv_cache_seq_keep (&lctx, meta.keep_seq_id); },
[&]{ llama_send_kv_cache_seq_keep(&lctx, meta.keep_seq_id); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_keep\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_div,
[&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
[&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
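The `kv_cache_op` helper itself is not shown in this diff. Based only on the call sites above, its shape is presumably something like the sketch below (apply the operation locally, forward it down the ring unless this is the last device, and tell the caller to bail out of the decode early); treat the signature and semantics as an inference, not the actual implementation.

```cpp
#include <functional>

// Inferred-from-call-sites sketch of kv_cache_op: when `signal` is set, replay
// the KV-cache operation locally, forward it to the next device unless this is
// the last one, and return true so llama_decode_internal can return early.
static bool kv_cache_op(bool signal,
                        const std::function<void()> & apply,
                        const std::function<void()> & forward,
                        bool is_last_dev) {
    if (!signal) {
        return false;
    }
    apply();
    if (!is_last_dev) {
        forward();
    }
    return true;
}
```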
@@ -22453,6 +22470,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
}
void llama_send_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
if (ctx->send_socket == nullptr) {
return;
}
try {
std::vector<zmq::message_t> msgs;
msgs.emplace_back("kv_seq_keep", strlen("kv_seq_keep"));
msgs.emplace_back(&seq_id, sizeof(seq_id));
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t & e) {
LLAMA_LOG_WARN("Failed to send kv_seq_keep: %s\n", e.what());
}
}
void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
if (delta == 0) {
return;