diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index f67ba107..af39f1ac 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1059,13 +1059,9 @@ struct server_context {
     }
 
     void kv_cache_clear() {
-        SRV_DBG("%s", "clearing KV cache\n");
-
-        // clear the entire KV cache
+        SRV_DBG("%s", "clearing all KV cache\n");
         llama_kv_cache_clear(ctx);
-        llama_send_kv_cache_clear(ctx);
-
         clean_kv_cache = false;
     }
 
@@ -1090,7 +1086,7 @@ struct server_context {
                 llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
             }
 
-            if (llama_decode(ctx, batch) != 0) {
+            if (llama_decode(ctx, batch, true) != 0) {
                 SRV_ERR("%s", "llama_decode() failed\n");
                 return;
             }
@@ -2311,7 +2307,7 @@ struct server_context {
                 0, 0, 0, // unused
             };
 
-            const int ret = llama_decode(ctx, batch_view);
+            const int ret = llama_decode(ctx, batch_view, true);
             metrics.on_decoded(slots);
 
             if (ret != 0) {
diff --git a/include/llama.h b/include/llama.h
index 8bb8ac50..86da593c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -957,7 +957,8 @@ extern "C" {
    // < 0 - error
    LLAMA_API int32_t llama_decode(
            struct llama_context * ctx,
-            struct llama_batch   batch);
+            struct llama_batch   batch,
+            bool server_mode = false);
 
    // Set the number of threads used for decoding
    // n_threads is the number of threads used for generation (single token)
diff --git a/src/llama.cpp b/src/llama.cpp
index c03d92a5..af42f79d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17849,7 +17849,7 @@ struct sync_meta {
     int div_factor = 1;
 };
 
-static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
+static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta, bool align_seq_ids = false) {
     GGML_ASSERT(meta != nullptr);
     try {
         std::vector<zmq::message_t> send_msgs;
@@ -17864,19 +17864,20 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         }
 
         if (meta->n_seq_id != nullptr) {
-            GGML_ASSERT(meta->n_ctx > 0);
+            GGML_ASSERT(meta->n_tokens > 0);
             send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
-            send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
+            send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));
 
             // here we assume only a single seq_id per token is needed
             // pack all single seq_id values into a contiguous array
-            llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_ctx * sizeof(llama_seq_id));
-            for (uint32_t i = 0; i < meta->n_ctx; ++i) {
-                all_seq_ids[i] = meta->seq_id[i][0];
+            llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_tokens * sizeof(llama_seq_id));
+            int seq_id_offset = align_seq_ids ? 1 : 0;
+            for (int32_t i = 0; i < meta->n_tokens; ++i) {
+                all_seq_ids[i] = meta->seq_id[i][0] - seq_id_offset;
             }
             send_msgs.emplace_back("seq_id", strlen("seq_id"));
-            send_msgs.emplace_back(all_seq_ids, meta->n_ctx * sizeof(llama_seq_id));
+            send_msgs.emplace_back(all_seq_ids, meta->n_tokens * sizeof(llama_seq_id));
             free(all_seq_ids);
         }
 
@@ -17966,18 +17967,18 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         }
 
         if (key == "n_seq_id") {
-            GGML_ASSERT(meta->n_ctx > 0);
-            GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
-            meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
-            std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
+            GGML_ASSERT(meta->n_tokens > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int32_t));
+            meta->n_seq_id = (int32_t *) malloc(meta->n_tokens * sizeof(int32_t));
+            std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
         }
 
         if (key == "seq_id") {
-            GGML_ASSERT(meta->n_ctx > 0);
-            GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(llama_seq_id));
+            GGML_ASSERT(meta->n_tokens > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(llama_seq_id));
             const llama_seq_id * all_seq_ids = (llama_seq_id *) data_msg.data();
-            meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
-            for (uint32_t i = 0; i < meta->n_ctx; ++i) {
+            meta->seq_id = (llama_seq_id **) malloc(meta->n_tokens * sizeof(llama_seq_id *));
+            for (int32_t i = 0; i < meta->n_tokens; ++i) {
                 meta->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
                 meta->seq_id[i][0] = all_seq_ids[i];
             }
@@ -18203,7 +18204,8 @@ static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool f
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   batch_all,
+           bool          server_mode) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -18275,16 +18277,16 @@ static int llama_decode_internal(
         if (meta.n_tokens > 0) {
             batch_all.n_tokens = meta.n_tokens;
             if (meta.pos != nullptr) {
-                batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
-                std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
+                batch_all.pos = (llama_pos *) malloc(meta.n_ctx * sizeof(llama_pos));
+                std::memcpy(batch_all.pos, meta.pos, meta.n_ctx * sizeof(llama_pos));
             }
             if (meta.n_seq_id != nullptr) {
-                batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t));
-                std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t));
+                batch_all.n_seq_id = (int32_t *) malloc(meta.n_tokens * sizeof(int32_t));
+                std::memcpy(batch_all.n_seq_id, meta.n_seq_id, meta.n_tokens * sizeof(int32_t));
             }
             if (meta.seq_id != nullptr) {
-                batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
-                for (size_t i = 0; i < cparams.n_ctx; ++i) {
+                batch_all.seq_id = (llama_seq_id **) malloc(meta.n_tokens * sizeof(llama_seq_id *));
+                for (int32_t i = 0; i < meta.n_tokens; ++i) {
                     batch_all.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
                     batch_all.seq_id[i][0] = meta.seq_id[i][0];
                 }
@@ -18346,7 +18348,7 @@ static int llama_decode_internal(
         meta.logits    = batch_all.logits;
         meta.all_pos_0 = batch_all.all_pos_0;
         meta.all_pos_1 = batch_all.all_pos_1;
-        llama_send_meta(*lctx.send_socket, &meta);
+        llama_send_meta(*lctx.send_socket, &meta, server_mode);
     }
 
     lctx.sbatch.from_batch(batch_all, n_embd,
@@ -23484,8 +23486,9 @@ int32_t llama_encode(
 
 int32_t llama_decode(
         struct llama_context * ctx,
-        struct llama_batch   batch) {
-    return llama_decode_internal(*ctx, batch);
+        struct llama_batch   batch,
+        bool server_mode) {
+    return llama_decode_internal(*ctx, batch, server_mode);
 }
 
 void llama_synchronize(struct llama_context * ctx) {
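
For reference, a minimal caller-side sketch of the updated signature, assuming the llama.h from this patch. The wrapper decode_one_token and its arguments are purely illustrative and not part of the change: the server passes true for the new flag (as in the server.cpp hunks above) so that llama_send_meta() subtracts 1 from every sequence id it broadcasts, while every other caller keeps the previous behaviour through the default false.

    // hypothetical helper, not part of the patch
    #include "llama.h"

    static int decode_one_token(llama_context * ctx, llama_token tok, llama_pos pos, bool is_server) {
        // build a single-token batch the same way the existing examples do
        llama_batch batch = llama_batch_init(/*n_tokens*/ 1, /*embd*/ 0, /*n_seq_max*/ 1);

        batch.n_tokens     = 1;
        batch.token[0]     = tok;
        batch.pos[0]       = pos;
        batch.n_seq_id[0]  = 1;
        batch.seq_id[0][0] = 0;
        batch.logits[0]    = true;

        // new third argument: true only for the server, so llama_send_meta()
        // shifts the broadcast seq_ids; everyone else relies on the default false
        const int ret = llama_decode(ctx, batch, /*server_mode=*/ is_server);

        llama_batch_free(batch);
        return ret;
    }

Because the = false default is a C++ feature resolved at the call site, existing C++ calls such as llama_decode(ctx, batch) keep compiling unchanged.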