From 500e066a2ffd7b575cdbd1a52261bf8a818ec3b9 Mon Sep 17 00:00:00 2001
From: Lizonghang <870644199@qq.com>
Date: Fri, 6 Jun 2025 16:53:22 +0400
Subject: [PATCH 1/6] fix batch decoding and dynamic batching

---
 common/arg.cpp | 14 ++++----
 src/llama.cpp  | 87 +++++++++++++++++++++++++-------------------------
 2 files changed, 51 insertions(+), 50 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 3dcaa051..47d3c5e6 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1424,13 +1424,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.defrag_thold = std::stof(value);
         }
     ).set_env("LLAMA_ARG_DEFRAG_THOLD"));
-    // add_opt(llama_arg(
-    //     {"-np", "--parallel"}, "N",
-    //     format("number of parallel sequences to decode (default: %d)", params.n_parallel),
-    //     [](gpt_params & params, int value) {
-    //         params.n_parallel = value;
-    //     }
-    // ).set_env("LLAMA_ARG_N_PARALLEL"));
+    add_opt(llama_arg(
+        {"-np", "--parallel"}, "N",
+        format("number of parallel sequences to decode (default: %d)", params.n_parallel),
+        [](gpt_params & params, int value) {
+            params.n_parallel = value;
+        }
+    ).set_env("LLAMA_ARG_N_PARALLEL"));
     add_opt(llama_arg(
         {"-ns", "--sequences"}, "N",
         format("number of sequences to decode (default: %d)", params.n_sequences),
diff --git a/src/llama.cpp b/src/llama.cpp
index 7129544b..c25e14d9 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2782,7 +2782,6 @@ struct llama_layer {
 // but has more metadata about sequences
 struct llama_ubatch {
     bool equal_seqs;
-    // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
@@ -2796,6 +2795,9 @@ struct llama_ubatch {
     int32_t      *  n_seq_id; // [n_seqs]
     llama_seq_id ** seq_id;   // [n_seqs]
     int8_t       *  output;   // [n_tokens]
+
+    bool activate_input;
+    bool activate_output;
 };
 
 struct llama_kv_cell {
@@ -3040,7 +3042,7 @@ struct llama_sbatch {
         ubatch_token.resize(!has_embd ? n_ubatch : 0);
         ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
         ubatch_backend_embd.resize(n_embd * n_tokens);
-        ubatch_out_embd.resize(n_embd);
+        ubatch_out_embd.resize(n_embd * n_tokens);
         ubatch_pos.resize(n_ubatch);
         ubatch_n_seq_id.resize(n_ubatch);
         ubatch_seq_id.resize(n_ubatch);
@@ -3058,6 +3060,8 @@ struct llama_sbatch {
             /*n_seq_id        =*/ ubatch_n_seq_id.data(),
             /*seq_id          =*/ ubatch_seq_id.data(),
             /*output          =*/ ubatch_output.data(),
+            /*activate_input  =*/ true,
+            /*activate_output =*/ false,
         };
         return ubatch;
     }
@@ -11104,7 +11108,6 @@ struct llm_build_context {
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
                 struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
                 cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
@@ -16978,7 +16981,7 @@ static std::vector<struct ggml_cgraph *> llama_build_graph(
 
     llm.init();
 
-    GGML_ASSERT((model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_QWEN2) && "this model is currently not supported");
+    GGML_ASSERT((model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_QWEN2) && "this model is currently not supported.\n");
 
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
@@ -17261,31 +17264,32 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     const auto & cparams = lctx.cparams;
     const auto & kv_self = lctx.kv_self;
 
-    if (batch.token) {
+    if (batch.activate_input) {
         const int64_t n_tokens = batch.n_tokens;
 
-        ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
-    }
+        if (batch.token) {
+            const int64_t size_ = n_tokens * ggml_element_size(lctx.inp_tokens);
+            ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, size_);
+        }
 
-    if (batch.embd) {
-        const int64_t n_embd   = hparams.n_embd;
-        const int64_t n_tokens = batch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
-    }
-
-    if (batch.backend_embd && lctx.backend_embd && lctx.backend_embd->data != nullptr) {
+        if (batch.embd) {
+            const int64_t n_embd = hparams.n_embd;
+            const int64_t size_  = n_tokens * n_embd * ggml_element_size(lctx.inp_embd);
+            ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, size_);
+        }
+    } else if (batch.activate_output) {
+        if (batch.out_embd && lctx.out_embd) {
+            const int64_t n_embd   = lctx.out_embd->ne[0];
+            const int64_t n_output = lctx.out_embd->ne[1];
+            const int64_t size_    = n_output * n_embd * ggml_element_size(lctx.out_embd);
+            ggml_backend_tensor_set(lctx.out_embd, batch.out_embd, 0, size_);
+        }
+    } else {
+        GGML_ASSERT(batch.backend_embd && lctx.backend_embd && lctx.backend_embd->data != nullptr);
         const int64_t n_embd   = lctx.backend_embd->ne[0];
         const int64_t n_tokens = lctx.backend_embd->ne[1];
-
-        ggml_backend_tensor_set(lctx.backend_embd, batch.backend_embd, 0, n_tokens*n_embd*ggml_element_size(lctx.backend_embd));
-    }
-
-    if (batch.out_embd && lctx.out_embd) {
-        const int64_t n_embd   = lctx.out_embd->ne[0];
-        const int64_t n_output = lctx.out_embd->ne[1];
-
-        ggml_backend_tensor_set(lctx.out_embd, batch.out_embd, 0, n_output*n_embd*ggml_element_size(lctx.out_embd));
+        const int64_t size_    = n_tokens * n_embd * ggml_element_size(lctx.backend_embd);
+        ggml_backend_tensor_set(lctx.backend_embd, batch.backend_embd, 0, size_);
     }
 
     if (batch.pos && lctx.inp_pos) {
@@ -17971,6 +17975,7 @@ static void llama_recv_tensors(zmq::socket_t & socket, struct llama_ubatch * uba
     std::vector<zmq::message_t> recv_msgs;
     if (!zmq::recv_multipart(socket, std::back_inserter(recv_msgs))) {
         LLAMA_LOG_INFO("Failed to receive tensor data.\n");
+        return;
     }
 
     for (size_t i = 0; i < recv_msgs.size(); i += 3) {
@@ -18281,8 +18286,7 @@ static int llama_decode_internal(
         return -2;
     };
 
-    { // assume there is only one batch
-    // while (lctx.sbatch.n_tokens > 0) { // handle multiple batches
+    while (lctx.sbatch.n_tokens > 0) { // handle multiple batches
         llama_ubatch ubatch;
         if (kv_self.recurrent) {
             if (embd_pooled) {
@@ -18300,26 +18304,19 @@ static int llama_decode_internal(
 
         // count the outputs in this u_batch
        int32_t n_outputs_new = 0;
-
-        if (my_rank == 0) {
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
-            }
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = n_tokens;
         } else {
-            n_outputs_new += 1;
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < n_tokens; i++) {
+                n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+            }
         }
-
         // needs to happen before the graph is built
         lctx.n_outputs = n_outputs_new;
 
         int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
         ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
-
         GGML_ASSERT(n_threads > 0);
 
         // non-causal masks do not use the KV cache
@@ -18394,11 +18391,11 @@ static int llama_decode_internal(
         GGML_ASSERT(my_rank == 0 || n_world > 1);
 
         for (size_t i = 0; i < (size_t)gf.size(); ++i) {
+            const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
            sub_gf = gf[i];
 
             // receive data from other nodes
             if (n_world > 1 && !(my_rank == 0 && i == 0) && !(my_rank == 0 && is_last_l)) {
-                const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
                 llama_recv_tensors(*lctx.recv_socket, &ubatch, is_out_embd);
             }
 
@@ -18407,6 +18404,10 @@ static int llama_decode_internal(
                 ggml_backend_sched_synchronize(lctx.sched[i - 1]);
             }
 
+            ubatch.activate_input  = (my_rank == 0 && i == 0);
+            ubatch.activate_output = (my_rank == 0 && is_out_embd);
+            GGML_ASSERT(!(ubatch.activate_input && ubatch.activate_output));
+
             llama_set_inputs(lctx, ubatch);
 
             { // compute graph
@@ -18442,13 +18443,13 @@ static int llama_decode_internal(
             GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));
             GGML_ASSERT(backend != nullptr);
             ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);
+            ggml_backend_sched_synchronize(lctx.sched[i]);
 
             // send the result to the next node or the master
             if (!(n_world == 1 || (my_rank == 0 && is_last_l))) {
                 struct input_tensors tensors = {sub_gf_out, lctx.inp_pos};
                 const bool is_to_master = my_rank != 0 && is_last_l;
                 zmq::socket_t * s = is_to_master ? lctx.master_socket : lctx.send_socket;
-                ggml_backend_sched_synchronize(lctx.sched[i]);
                 llama_send_tensors(*s, &ubatch, &tensors);
             }
 
@@ -19038,7 +19039,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
         llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, true, false};
         std::vector<struct ggml_cgraph *> gf = llama_build_graph(lctx, ubatch, true);
 
         GGML_ASSERT(lctx.sched.size() == gf.size());
@@ -21115,7 +21116,7 @@ void * llama_context_setup_backend(
         uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
         llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, true, false};
         std::vector<struct ggml_cgraph *> gf = llama_build_graph(*ctx, ubatch, true);
 
         GGML_ASSERT(gf.size() <= MAX_SCHEDULERS && "Number of subgraphs exceeds the maximum number of schedulers\n");
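
The core change in PATCH 1 is dropping the single-batch assumption in llama_decode_internal: the restored while loop keeps splitting lctx.sbatch into u-batches until it is drained, so prompts longer than n_ubatch and multiple parallel sequences (re-enabled via -np/--parallel) decode correctly. The standalone sketch below illustrates only that draining pattern; toy_sbatch, split_ubatch, and process_ubatch are invented names for this illustration and are not llama.cpp symbols.

    // minimal sketch of the u-batch draining loop, under the assumptions above
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct toy_sbatch {
        std::vector<int32_t> tokens;                 // tokens still queued for decode
        size_t n_tokens() const { return tokens.size(); }
    };

    // peel at most n_ubatch tokens off the front of the sequence-batch
    static std::vector<int32_t> split_ubatch(toy_sbatch & sb, size_t n_ubatch) {
        const size_t n = std::min(n_ubatch, sb.tokens.size());
        std::vector<int32_t> ub(sb.tokens.begin(), sb.tokens.begin() + n);
        sb.tokens.erase(sb.tokens.begin(), sb.tokens.begin() + n);
        return ub;
    }

    static void process_ubatch(const std::vector<int32_t> & ub) {
        std::printf("decoding u-batch of %zu tokens\n", ub.size());
    }

    int main() {
        toy_sbatch sb;
        sb.tokens.assign(10, 1);     // pretend 10 prompt tokens are queued
        const size_t n_ubatch = 4;

        while (sb.n_tokens() > 0) {  // mirrors: while (lctx.sbatch.n_tokens > 0)
            process_ubatch(split_ubatch(sb, n_ubatch));
        }
        return 0;
    }

Each pass peels off at most n_ubatch tokens and decodes them, which is what the restored loop does with real u-batches; the activate_input/activate_output flags added to llama_ubatch then decide, per sub-graph and per rank, which input tensors actually get uploaded.
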
From 68ecc8509dc9d42fd5de7863c89b381618ec29b8 Mon Sep 17 00:00:00 2001
From: Lizonghang <870644199@qq.com>
Date: Fri, 6 Jun 2025 22:58:48 +0400
Subject: [PATCH 2/6] add batch_all.logits to sync_meta

---
 src/llama.cpp | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/llama.cpp b/src/llama.cpp
index c25e14d9..7895b8f6 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17810,6 +17810,7 @@ struct sync_meta {
     llama_pos all_pos_0;
     llama_pos all_pos_1;
     uint32_t n_ctx = 0;
+    int8_t * logits = nullptr;
 
     // signal to clear the kv cache
     bool clear_kv_cache = false;
@@ -17862,6 +17863,12 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
         send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
 
+        if (meta->logits != nullptr) {
+            GGML_ASSERT(meta->n_tokens > 0);
+            send_msgs.emplace_back("logits", strlen("logits"));
+            send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t));
+        }
+
         zmq::send_multipart(socket, send_msgs);
     } catch (const zmq::error_t& e) {
         LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
@@ -17944,6 +17951,13 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
             GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
             std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
         }
+
+        if (key == "logits") {
+            GGML_ASSERT(meta->n_tokens > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
+            meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
+            std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
+        }
     }
     return 0;
 }
@@ -18225,6 +18239,10 @@ static int llama_decode_internal(
         }
         batch_all.all_pos_0 = meta.all_pos_0;
         batch_all.all_pos_1 = meta.all_pos_1;
+        if (meta.logits != nullptr) {
+            batch_all.logits = (int8_t *) malloc(meta.n_tokens * sizeof(int8_t));
+            std::memcpy(batch_all.logits, meta.logits, meta.n_tokens * sizeof(int8_t));
+        }
     }
 
     if (kv_cache_op(meta.clear_kv_cache,
@@ -18273,6 +18291,7 @@ static int llama_decode_internal(
        meta.pos = batch_all.pos;
         meta.all_pos_0 = batch_all.all_pos_0;
         meta.all_pos_1 = batch_all.all_pos_1;
+        meta.logits = batch_all.logits;
 
         llama_send_meta(*lctx.send_socket, &meta);
     }

From a1a22388310a682e9bb95514ed95b175cbc62e76 Mon Sep 17 00:00:00 2001
From: Lizonghang <870644199@qq.com>
Date: Fri, 6 Jun 2025 23:36:53 +0400
Subject: [PATCH 3/6] add batch_all.n_seq_id and batch_all.seq_id to sync_meta

---
 src/llama.cpp | 97 +++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 78 insertions(+), 19 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 7895b8f6..9b1d2a11 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17805,13 +17805,16 @@ struct input_tensors {
 };
 
 struct sync_meta {
-    int32_t n_tokens = 0;
-    llama_pos * pos = nullptr;
+    int32_t         n_tokens = 0;
+    llama_pos     * pos      = nullptr;
+    int32_t       * n_seq_id = nullptr;
+    llama_seq_id ** seq_id   = nullptr;
+    int8_t        * logits   = nullptr;
+
     llama_pos all_pos_0;
     llama_pos all_pos_1;
-    uint32_t n_ctx = 0;
-    int8_t * logits = nullptr;
-
+    uint32_t  n_ctx = 0;
+
     // signal to clear the kv cache
     bool clear_kv_cache = false;
 
@@ -17857,11 +17860,19 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
             send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
         }
 
-        send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
-        send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
+        if (meta->n_seq_id != nullptr) {
+            GGML_ASSERT(meta->seq_id != nullptr);
+            send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
+            send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
 
-        send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
-        send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
+            for (size_t i = 0; i < meta->n_ctx; ++i) {
+                const size_t n_seq = meta->n_seq_id[i];
+                if (n_seq > 0) {
+                    send_msgs.emplace_back("seq_id", strlen("seq_id"));
+                    send_msgs.emplace_back(meta->seq_id[i], n_seq * sizeof(llama_seq_id));
+                }
+            }
+        }
 
         if (meta->logits != nullptr) {
             GGML_ASSERT(meta->n_tokens > 0);
@@ -17869,6 +17880,12 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
             send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t));
         }
 
+        send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
+        send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
+
+        send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
+        send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
+
         zmq::send_multipart(socket, send_msgs);
     } catch (const zmq::error_t& e) {
         LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
@@ -17928,6 +17945,7 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         return 0;
     }
 
+    size_t seq_idx = 0;
     for (size_t i = 0; i < recv_msgs.size(); i += 2) {
         std::string key = recv_msgs[i].to_string();
         zmq::message_t & data_msg = recv_msgs[i + 1];
@@ -17942,6 +17960,41 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
             std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
         }
 
+        if (key == "n_seq_id") {
+            GGML_ASSERT(meta->n_ctx > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
+            meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
+            memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
+
+            meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
+            for (size_t j = 0; j < meta->n_ctx; ++j) {
+                meta->seq_id[j] = nullptr;
+            }
+            seq_idx = 0;
+        }
+
+        if (key == "seq_id") {
+            while (seq_idx < meta->n_ctx && meta->n_seq_id[seq_idx] <= 0) {
+                ++seq_idx;
+            }
+            GGML_ASSERT(seq_idx < meta->n_ctx);
+
+            const size_t n_seq = meta->n_seq_id[seq_idx];
+            GGML_ASSERT(data_msg.size() == n_seq * sizeof(llama_seq_id));
+
+            meta->seq_id[seq_idx] = (llama_seq_id *) malloc(n_seq * sizeof(llama_seq_id));
+            memcpy(meta->seq_id[seq_idx], data_msg.data(), n_seq * sizeof(llama_seq_id));
+
+            ++seq_idx;
+        }
+
+        if (key == "logits") {
+            GGML_ASSERT(meta->n_tokens > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
+            meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
+            std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
+        }
+
         if (key == "all_pos_0") {
             GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
             std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
@@ -17951,13 +18004,6 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
             GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
             std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
         }
-
-        if (key == "logits") {
-            GGML_ASSERT(meta->n_tokens > 0);
-            GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
-            meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
-            std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
-        }
     }
     return 0;
 }
@@ -18237,12 +18283,23 @@ static int llama_decode_internal(
             batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
             std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
         }
-        batch_all.all_pos_0 = meta.all_pos_0;
-        batch_all.all_pos_1 = meta.all_pos_1;
+        if (meta.n_seq_id != nullptr) {
+            batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t));
+            std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t));
+        }
+        if (meta.seq_id != nullptr) {
+            batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
+            for (size_t i = 0; i < cparams.n_ctx; ++i) {
+                batch_all.seq_id[i] = (llama_seq_id *) malloc(meta.n_seq_id[i] * sizeof(llama_seq_id));
+                std::memcpy(batch_all.seq_id[i], meta.seq_id[i], meta.n_seq_id[i] * sizeof(llama_seq_id));
+            }
+        }
         if (meta.logits != nullptr) {
             batch_all.logits = (int8_t *) malloc(meta.n_tokens * sizeof(int8_t));
             std::memcpy(batch_all.logits, meta.logits, meta.n_tokens * sizeof(int8_t));
         }
+        batch_all.all_pos_0 = meta.all_pos_0;
+        batch_all.all_pos_1 = meta.all_pos_1;
     }
 
     if (kv_cache_op(meta.clear_kv_cache,
@@ -18289,9 +18346,11 @@ static int llama_decode_internal(
     if (!is_last_dev) {
         meta.n_tokens = batch_all.n_tokens;
         meta.pos = batch_all.pos;
+        meta.n_seq_id = batch_all.n_seq_id;
+        meta.seq_id = batch_all.seq_id;
+        meta.logits = batch_all.logits;
         meta.all_pos_0 = batch_all.all_pos_0;
         meta.all_pos_1 = batch_all.all_pos_1;
-        meta.logits = batch_all.logits;
 
         llama_send_meta(*lctx.send_socket, &meta);
     }

From d8aea899d1675d0dff3bc608019f44403be969ac Mon Sep 17 00:00:00 2001
From: Lizonghang <870644199@qq.com>
Date: Fri, 6 Jun 2025 23:58:03 +0400
Subject: [PATCH 4/6] fix n_seq_id and seq_id

---
 src/llama.cpp | 29 +++++++++++++----------------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 9b1d2a11..4ab1317a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17861,16 +17861,14 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         }
 
         if (meta->n_seq_id != nullptr) {
-            GGML_ASSERT(meta->seq_id != nullptr);
             send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
-            send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
+            send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));
 
-            for (size_t i = 0; i < meta->n_ctx; ++i) {
+            for (int32_t i = 0; i < meta->n_tokens; ++i) {
+                GGML_ASSERT(meta->seq_id[i] != nullptr);
                 const size_t n_seq = meta->n_seq_id[i];
-                if (n_seq > 0) {
-                    send_msgs.emplace_back("seq_id", strlen("seq_id"));
-                    send_msgs.emplace_back(meta->seq_id[i], n_seq * sizeof(llama_seq_id));
-                }
+                send_msgs.emplace_back("seq_id", strlen("seq_id"));
+                send_msgs.emplace_back(meta->seq_id[i], n_seq * sizeof(llama_seq_id));
             }
         }
 
@@ -17945,7 +17943,7 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         return 0;
     }
 
-    size_t seq_idx = 0;
+    int32_t seq_idx = 0;
     for (size_t i = 0; i < recv_msgs.size(); i += 2) {
         std::string key = recv_msgs[i].to_string();
         zmq::message_t & data_msg = recv_msgs[i + 1];
@@ -17961,23 +17959,22 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         }
 
         if (key == "n_seq_id") {
-            GGML_ASSERT(meta->n_ctx > 0);
-            GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
-            meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
-            memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
+            GGML_ASSERT(meta->n_tokens > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int32_t));
+            meta->n_seq_id = (int32_t *) malloc(meta->n_tokens * sizeof(int32_t));
+            memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
 
-            meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
-            for (size_t j = 0; j < meta->n_ctx; ++j) {
+            meta->seq_id = (llama_seq_id **) malloc(meta->n_tokens * sizeof(llama_seq_id *));
+            for (int32_t j = 0; j < meta->n_tokens; ++j) {
                 meta->seq_id[j] = nullptr;
             }
             seq_idx = 0;
         }
 
         if (key == "seq_id") {
-            while (seq_idx < meta->n_ctx && meta->n_seq_id[seq_idx] <= 0) {
+            while (seq_idx < meta->n_tokens && meta->n_seq_id[seq_idx] <= 0) {
                 ++seq_idx;
             }
-            GGML_ASSERT(seq_idx < meta->n_ctx);
 
             const size_t n_seq = meta->n_seq_id[seq_idx];
             GGML_ASSERT(data_msg.size() == n_seq * sizeof(llama_seq_id));
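
Patches 2-4 extend sync_meta so that per-token metadata (the logits flags, n_seq_id, and seq_id) travels with the batch to the other ranks. The transport convention is the one already used above: each field is a pair of ZeroMQ frames, a name frame followed by a payload frame, and the receiver dispatches on the name. The sketch below reproduces that key/value pattern in isolation with cppzmq's zmq_addon.hpp helpers (the same send_multipart/recv_multipart calls the patches use); the inproc PAIR sockets and the single "n_seq_id" field are assumptions made for this example, not the project's actual socket topology.

    // minimal sketch of the [name frame, payload frame] metadata exchange
    #include <cstdint>
    #include <cstdio>
    #include <cstring>
    #include <iterator>
    #include <string>
    #include <vector>
    #include <zmq.hpp>
    #include <zmq_addon.hpp>

    int main() {
        zmq::context_t ctx;
        zmq::socket_t rx(ctx, zmq::socket_type::pair);
        zmq::socket_t tx(ctx, zmq::socket_type::pair);
        rx.bind("inproc://meta");
        tx.connect("inproc://meta");

        // sender: one field = a name frame followed by a payload frame
        std::vector<int32_t> n_seq_id = {1, 1, 1, 1};
        std::vector<zmq::message_t> out;
        out.emplace_back("n_seq_id", strlen("n_seq_id"));
        out.emplace_back(n_seq_id.data(), n_seq_id.size() * sizeof(int32_t));
        zmq::send_multipart(tx, out);

        // receiver: walk the frames two at a time and dispatch on the key
        std::vector<zmq::message_t> in;
        if (!zmq::recv_multipart(rx, std::back_inserter(in))) {
            std::fprintf(stderr, "failed to receive meta data\n");
            return 1;
        }
        for (size_t i = 0; i + 1 < in.size(); i += 2) {
            const std::string key = in[i].to_string();
            zmq::message_t & data_msg = in[i + 1];
            if (key == "n_seq_id") {
                std::vector<int32_t> recvd(data_msg.size() / sizeof(int32_t));
                std::memcpy(recvd.data(), data_msg.data(), data_msg.size());
                std::printf("received %zu n_seq_id entries\n", recvd.size());
            }
        }
        return 0;
    }

Because the receiver simply skips keys it does not recognize, new fields such as logits, n_seq_id, and seq_id can be added to the metadata incrementally, which is exactly how these three patches build on one another.
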
From e56be76bdf6a2b6b12eb126bd592656767d5052f Mon Sep 17 00:00:00 2001
From: Lizonghang <870644199@qq.com>
Date: Sat, 7 Jun 2025 00:42:44 +0400
Subject: [PATCH 5/6] assume only a single seq_id per token is needed

---
 src/llama.cpp | 49 ++++++++++++++++++++++---------------------------
 1 file changed, 22 insertions(+), 27 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 4ab1317a..5531ddcc 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17861,15 +17861,20 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         }
 
         if (meta->n_seq_id != nullptr) {
+            GGML_ASSERT(meta->n_ctx > 0);
             send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
-            send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));
+            send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
 
-            for (int32_t i = 0; i < meta->n_tokens; ++i) {
-                GGML_ASSERT(meta->seq_id[i] != nullptr);
-                const size_t n_seq = meta->n_seq_id[i];
-                send_msgs.emplace_back("seq_id", strlen("seq_id"));
-                send_msgs.emplace_back(meta->seq_id[i], n_seq * sizeof(llama_seq_id));
+            // here we assume only a single seq_id per token is needed
+            // pack all single seq_id values into a contiguous array
+            llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_ctx * sizeof(llama_seq_id));
+            for (uint32_t i = 0; i < meta->n_ctx; ++i) {
+                all_seq_ids[i] = meta->seq_id[i][0];
             }
+
+            send_msgs.emplace_back("seq_id", strlen("seq_id"));
+            send_msgs.emplace_back(all_seq_ids, meta->n_ctx * sizeof(llama_seq_id));
+            free(all_seq_ids);
         }
 
         if (meta->logits != nullptr) {
@@ -17943,7 +17948,6 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         return 0;
     }
 
-    int32_t seq_idx = 0;
     for (size_t i = 0; i < recv_msgs.size(); i += 2) {
         std::string key = recv_msgs[i].to_string();
         zmq::message_t & data_msg = recv_msgs[i + 1];
@@ -17959,30 +17963,21 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         }
 
         if (key == "n_seq_id") {
-            GGML_ASSERT(meta->n_tokens > 0);
-            GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int32_t));
-            meta->n_seq_id = (int32_t *) malloc(meta->n_tokens * sizeof(int32_t));
-            memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
-
-            meta->seq_id = (llama_seq_id **) malloc(meta->n_tokens * sizeof(llama_seq_id *));
-            for (int32_t j = 0; j < meta->n_tokens; ++j) {
-                meta->seq_id[j] = nullptr;
-            }
-            seq_idx = 0;
+            GGML_ASSERT(meta->n_ctx > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
+            meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
+            std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
         }
 
         if (key == "seq_id") {
-            while (seq_idx < meta->n_tokens && meta->n_seq_id[seq_idx] <= 0) {
-                ++seq_idx;
+            GGML_ASSERT(meta->n_ctx > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(llama_seq_id));
+            const llama_seq_id * all_seq_ids = (llama_seq_id *) data_msg.data();
+            meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
+            for (uint32_t i = 0; i < meta->n_ctx; ++i) {
+                meta->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
+                meta->seq_id[i][0] = all_seq_ids[i];
             }
-
-            const size_t n_seq = meta->n_seq_id[seq_idx];
-            GGML_ASSERT(data_msg.size() == n_seq * sizeof(llama_seq_id));
-
-            meta->seq_id[seq_idx] = (llama_seq_id *) malloc(n_seq * sizeof(llama_seq_id));
-            memcpy(meta->seq_id[seq_idx], data_msg.data(), n_seq * sizeof(llama_seq_id));
-
-            ++seq_idx;
         }
 
         if (key == "logits") {

From 22a6ddef13c3611ac474ebb3cfe85994b52f6616 Mon Sep 17 00:00:00 2001
From: "Li, Zonghang" <870644199@qq.com>
Date: Sat, 7 Jun 2025 00:53:56 +0400
Subject: [PATCH 6/6] fix batch decoding and dynamic batching

---
 src/llama.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 5531ddcc..0e615b67 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18282,8 +18282,8 @@ static int llama_decode_internal(
         if (meta.seq_id != nullptr) {
             batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
             for (size_t i = 0; i < cparams.n_ctx; ++i) {
-                batch_all.seq_id[i] = (llama_seq_id *) malloc(meta.n_seq_id[i] * sizeof(llama_seq_id));
-                std::memcpy(batch_all.seq_id[i], meta.seq_id[i], meta.n_seq_id[i] * sizeof(llama_seq_id));
+                batch_all.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
+                batch_all.seq_id[i][0] = meta.seq_id[i][0];
             }
         }
         if (meta.logits != nullptr) {
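
Patches 5 and 6 simplify the seq_id exchange under the stated assumption that each token carries exactly one sequence id: the sender flattens seq_id[i][0] into one contiguous array (a single "seq_id" frame), and the receiver rebuilds a per-token table with one entry each. A minimal sketch of that pack/unpack round trip follows; the types and helpers are illustrative stand-ins, not llama.cpp code.

    // sketch of the single-seq_id-per-token flattening, under the assumption above
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using seq_id_t = int32_t;   // stands in for llama_seq_id

    // sender side: keep only the first seq_id of every token
    static std::vector<seq_id_t> pack_first_seq_ids(
            const std::vector<std::vector<seq_id_t>> & seq_id) {
        std::vector<seq_id_t> flat;
        flat.reserve(seq_id.size());
        for (const auto & ids : seq_id) {
            flat.push_back(ids.front());   // assumes every token has at least one seq id
        }
        return flat;
    }

    // receiver side: rebuild a per-token table with exactly one seq_id per token
    static std::vector<std::vector<seq_id_t>> unpack_single_seq_ids(
            const std::vector<seq_id_t> & flat) {
        std::vector<std::vector<seq_id_t>> seq_id(flat.size());
        for (size_t i = 0; i < flat.size(); ++i) {
            seq_id[i] = { flat[i] };
        }
        return seq_id;
    }

    int main() {
        // two parallel sequences with two tokens each, as -np 2 would produce
        std::vector<std::vector<seq_id_t>> seq_id = {{0}, {0}, {1}, {1}};
        const auto flat     = pack_first_seq_ids(seq_id);
        const auto restored = unpack_single_seq_ids(flat);
        std::printf("packed %zu ids, restored %zu tokens\n",
                    flat.size(), restored.size());
        return 0;
    }

The trade-off is that a token shared by several sequences would lose all but its first membership; the commit message makes that assumption explicit.
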