// Patch summary (review preamble; text ahead of the "diff --git" header is not part of any hunk).
//
// Purpose: carry a per-token `logits` flag buffer inside sync_meta so that it travels with the
// other metadata exchanged over the ZeroMQ sockets.
//
// Hunk 1 (struct sync_meta): adds the member `int8_t * logits`, defaulting to nullptr.
// Hunk 2 (llama_send_meta): when meta->logits is non-null, asserts meta->n_tokens > 0 and appends
//   two frames to send_msgs: the key "logits" and n_tokens * sizeof(int8_t) bytes read from
//   meta->logits. A null pointer means no "logits" frames are queued by this hunk.
// Hunk 3 (llama_recv_meta): on key "logits", asserts meta->n_tokens > 0 and that the data frame is
//   exactly n_tokens * sizeof(int8_t) bytes, then mallocs a buffer of that size and memcpy's the
//   frame into it.
// Hunk 4 (llama_decode_internal): when meta.logits is non-null, mallocs batch_all.logits and copies
//   n_tokens bytes from meta.logits into it.
// Hunk 5 (llama_decode_internal): sets meta.logits = batch_all.logits (pointer assignment, no copy)
//   before calling llama_send_meta.
//
// NOTE(review): both malloc results are passed straight to std::memcpy without a null check.
// NOTE(review): no free() of meta->logits or batch_all.logits appears in these hunks - confirm who
//   owns and releases the buffers; the receive path also overwrites meta->logits without releasing
//   any previous value.
// NOTE(review): the "logits" handler relies on meta->n_tokens already being populated when the key
//   is processed; frame ordering is not visible in these hunks - TODO confirm.
diff --git a/src/llama.cpp b/src/llama.cpp index c25e14d9..7895b8f6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17810,6 +17810,7 @@ struct sync_meta { llama_pos all_pos_0; llama_pos all_pos_1; uint32_t n_ctx = 0; + int8_t * logits = nullptr; // signal to clear the kv cache bool clear_kv_cache = false; @@ -17862,6 +17863,12 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) { send_msgs.emplace_back("all_pos_1", strlen("all_pos_1")); send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1)); + if (meta->logits != nullptr) { + GGML_ASSERT(meta->n_tokens > 0); + send_msgs.emplace_back("logits", strlen("logits")); + send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t)); + } + zmq::send_multipart(socket, send_msgs); } catch (const zmq::error_t& e) { LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what()); @@ -17944,6 +17951,13 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) { GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1)); std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1)); } + + if (key == "logits") { + GGML_ASSERT(meta->n_tokens > 0); + GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t)); + meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t)); + std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t)); + } } return 0; } @@ -18225,6 +18239,10 @@ static int llama_decode_internal( } batch_all.all_pos_0 = meta.all_pos_0; batch_all.all_pos_1 = meta.all_pos_1; + if (meta.logits != nullptr) { + batch_all.logits = (int8_t *) malloc(meta.n_tokens * sizeof(int8_t)); + std::memcpy(batch_all.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); + } } if (kv_cache_op(meta.clear_kv_cache, @@ -18273,6 +18291,7 @@ static int llama_decode_internal( meta.pos = batch_all.pos; meta.all_pos_0 = batch_all.all_pos_0; meta.all_pos_1 = batch_all.all_pos_1; + meta.logits = batch_all.logits; llama_send_meta(*lctx.send_socket, 
&meta); }