mirror of https://github.com/Lizonghang/prima.cpp.git (synced 2025-09-05 23:39:05 +00:00)
fix seq_id mismatch between head and worker devices
commit 3e6d831930 (parent fb9b1f2b00)
3 changed files with 33 additions and 33 deletions
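
Context for the fix: in the llama.cpp server code that prima.cpp tracks, each slot decodes into KV-cache sequence slot.id + 1, with sequence 0 reserved for the system prompt (visible below, where system tokens are added with seq_id { 0 }). A worker that numbers its sequences from 0 would therefore address the wrong KV cells for slot traffic. This commit threads a server_mode flag from the server's llama_decode() calls down to llama_send_meta(), which strips the head's offset before shipping seq_ids to the workers. A minimal sketch of the intended alignment (illustrative, not code from this commit):

    // Head (server) numbering: 0 = system prompt, k + 1 = slot k.
    // Workers number their sequences from 0, so the head shifts every
    // seq_id down by one on the wire:
    //   head:   1  2  3 ...
    //   worker: 0  1  2 ...
    inline llama_seq_id head_to_worker(llama_seq_id id) { return id - 1; }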

@@ -1059,13 +1059,9 @@ struct server_context {
     }

     void kv_cache_clear() {
-        SRV_DBG("%s", "clearing KV cache\n");
-
-        // clear the entire KV cache
+        SRV_DBG("%s", "clearing all KV cache\n");
         llama_kv_cache_clear(ctx);
+        llama_send_kv_cache_clear(ctx);

         clean_kv_cache = false;
     }
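
The head device now mirrors its local clear to the workers via llama_send_kv_cache_clear(). That function's body is not part of this diff; a plausible sketch, assuming it reuses the same cppzmq transport that llama_send_meta() writes to (name and framing are guesses):

    // Hypothetical: broadcast a tagged, empty control message so each
    // worker clears its own KV cache in step with the head device.
    static void send_kv_cache_clear_sketch(zmq::socket_t & socket) {
        zmq::message_t tag("kv_cache_clear", strlen("kv_cache_clear"));
        socket.send(tag, zmq::send_flags::none);
    }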

@@ -1090,7 +1086,7 @@ struct server_context {
             llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
         }

-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode(ctx, batch, true) != 0) {
             SRV_ERR("%s", "llama_decode() failed\n");
             return;
         }

@@ -2311,7 +2307,7 @@ struct server_context {
                 0, 0, 0, // unused
             };

-            const int ret = llama_decode(ctx, batch_view);
+            const int ret = llama_decode(ctx, batch_view, true);
             metrics.on_decoded(slots);

             if (ret != 0) {
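
Both call sites above live in the server (the struct server_context hunks), and the server process is always the head device, so they now pass true unconditionally. Every other caller keeps the old two-argument form:

    llama_decode(ctx, batch, true);  // head device (server): align seq_ids for workers
    llama_decode(ctx, batch);        // elsewhere: server_mode defaults to false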

@@ -957,7 +957,8 @@ extern "C" {
     // < 0 - error
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
-            struct llama_batch batch);
+            struct llama_batch batch,
+            bool server_mode = false);

     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
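
Defaulting the new parameter keeps existing C++ call sites source-compatible, which is why only the two server call sites needed touching. One caveat: default arguments are a C++-only feature, so a plain C consumer of this extern "C" declaration would have to spell the argument out:

    llama_decode(ctx, batch);         // C++: equivalent to llama_decode(ctx, batch, false)
    llama_decode(ctx, batch, false);  // C: the default does not apply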

@@ -17849,7 +17849,7 @@ struct sync_meta {
     int div_factor = 1;
 };

-static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
+static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta, bool align_seq_ids = false) {
     GGML_ASSERT(meta != nullptr);
     try {
         std::vector<zmq::message_t> send_msgs;

@@ -17864,19 +17864,20 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
     }

     if (meta->n_seq_id != nullptr) {
-        GGML_ASSERT(meta->n_ctx > 0);
+        GGML_ASSERT(meta->n_tokens > 0);
         send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
-        send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
+        send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));

         // here we assume only a single seq_id per token is needed
         // pack all single seq_id values into a contiguous array
-        llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_ctx * sizeof(llama_seq_id));
-        for (uint32_t i = 0; i < meta->n_ctx; ++i) {
-            all_seq_ids[i] = meta->seq_id[i][0];
+        llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_tokens * sizeof(llama_seq_id));
+        int seq_id_offset = align_seq_ids ? 1 : 0;
+        for (int32_t i = 0; i < meta->n_tokens; ++i) {
+            all_seq_ids[i] = meta->seq_id[i][0] - seq_id_offset;
         }

         send_msgs.emplace_back("seq_id", strlen("seq_id"));
-        send_msgs.emplace_back(all_seq_ids, meta->n_ctx * sizeof(llama_seq_id));
+        send_msgs.emplace_back(all_seq_ids, meta->n_tokens * sizeof(llama_seq_id));
         free(all_seq_ids);
     }
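
Two fixes land in this hunk. First, the n_seq_id and seq_id payloads are now sized by meta->n_tokens (tokens actually in the batch) rather than meta->n_ctx (the whole context window), so the head stops reading past the end of its per-batch arrays and stops shipping mostly-uninitialized bytes. Second, when align_seq_ids is set, the head's +1 slot offset is removed as the values are packed. The packing loop is equivalent to this standalone sketch (hypothetical helper, not in the commit):

    // Pack one seq_id per token for the wire, shifting out the head's
    // slot offset when the sender is the server (align_seq_ids == true).
    // Example: n_tokens = 4, head seq_ids {1, 1, 2, 3} -> packed {0, 0, 1, 2}.
    static std::vector<llama_seq_id> pack_seq_ids(const sync_meta & meta, bool align) {
        std::vector<llama_seq_id> packed(meta.n_tokens);
        for (int32_t i = 0; i < meta.n_tokens; ++i) {
            packed[i] = meta.seq_id[i][0] - (align ? 1 : 0);
        }
        return packed;
    }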

@@ -17966,18 +17967,18 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
     }

     if (key == "n_seq_id") {
-        GGML_ASSERT(meta->n_ctx > 0);
-        GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
-        meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
-        std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
+        GGML_ASSERT(meta->n_tokens > 0);
+        GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int32_t));
+        meta->n_seq_id = (int32_t *) malloc(meta->n_tokens * sizeof(int32_t));
+        std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
     }

     if (key == "seq_id") {
-        GGML_ASSERT(meta->n_ctx > 0);
-        GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(llama_seq_id));
+        GGML_ASSERT(meta->n_tokens > 0);
+        GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(llama_seq_id));
         const llama_seq_id * all_seq_ids = (llama_seq_id *) data_msg.data();
-        meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
-        for (uint32_t i = 0; i < meta->n_ctx; ++i) {
+        meta->seq_id = (llama_seq_id **) malloc(meta->n_tokens * sizeof(llama_seq_id *));
+        for (int32_t i = 0; i < meta->n_tokens; ++i) {
             meta->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
             meta->seq_id[i][0] = all_seq_ids[i];
         }
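
The receive side mirrors the sender: buffers are allocated and validated against n_tokens, so the size assertions double as a wire-format check, and a head and worker that disagree about sizing now assert instead of copying garbage (the loop index also switches to int32_t to match). The payload shrinks accordingly, e.g. for a 4-token batch in a 4096-slot context:

    // old payload: n_ctx    * sizeof(llama_seq_id) = 4096 * 4 bytes = 16 KiB
    // new payload: n_tokens * sizeof(llama_seq_id) =    4 * 4 bytes = 16 B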

@@ -18203,7 +18204,8 @@ static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool f
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   batch_all,
+                    bool server_mode) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
|
@ -18275,16 +18277,16 @@ static int llama_decode_internal(
|
|||
if (meta.n_tokens > 0) {
|
||||
batch_all.n_tokens = meta.n_tokens;
|
||||
if (meta.pos != nullptr) {
|
||||
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
|
||||
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
|
||||
batch_all.pos = (llama_pos *) malloc(meta.n_ctx * sizeof(llama_pos));
|
||||
std::memcpy(batch_all.pos, meta.pos, meta.n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
if (meta.n_seq_id != nullptr) {
|
||||
batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t));
|
||||
std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t));
|
||||
batch_all.n_seq_id = (int32_t *) malloc(meta.n_tokens * sizeof(int32_t));
|
||||
std::memcpy(batch_all.n_seq_id, meta.n_seq_id, meta.n_tokens * sizeof(int32_t));
|
||||
}
|
||||
if (meta.seq_id != nullptr) {
|
||||
batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
|
||||
for (size_t i = 0; i < cparams.n_ctx; ++i) {
|
||||
batch_all.seq_id = (llama_seq_id **) malloc(meta.n_tokens * sizeof(llama_seq_id *));
|
||||
for (int32_t i = 0; i < meta.n_tokens; ++i) {
|
||||
batch_all.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
|
||||
batch_all.seq_id[i][0] = meta.seq_id[i][0];
|
||||
}
|
||||
|
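
Previously these copies were sized by the local cparams.n_ctx while the sender had only allocated per-batch buffers, an out-of-bounds read whenever the two disagreed. Every copy is now bounded by what llama_recv_meta() actually received: meta.n_tokens for the sequence arrays, meta.n_ctx for pos. The implied invariant (illustrative assert, not in the commit):

    // meta's buffers are exactly as large as the head said they are,
    // so sizing by the local context length would over-read them.
    GGML_ASSERT(meta.n_tokens > 0 && meta.n_tokens <= (int32_t) cparams.n_ctx);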

@@ -18346,7 +18348,7 @@ static int llama_decode_internal(
         meta.logits    = batch_all.logits;
         meta.all_pos_0 = batch_all.all_pos_0;
         meta.all_pos_1 = batch_all.all_pos_1;
-        llama_send_meta(*lctx.send_socket, &meta);
+        llama_send_meta(*lctx.send_socket, &meta, server_mode);
     }

     lctx.sbatch.from_batch(batch_all, n_embd,

@@ -23484,8 +23486,9 @@ int32_t llama_encode(

 int32_t llama_decode(
         struct llama_context * ctx,
-        struct llama_batch batch) {
-    return llama_decode_internal(*ctx, batch);
+        struct llama_batch batch,
+        bool server_mode) {
+    return llama_decode_internal(*ctx, batch, server_mode);
 }

 void llama_synchronize(struct llama_context * ctx) {
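
End to end, the flag travels from the public API to the wire like this:

    // llama_decode(ctx, batch, server_mode)
    //   -> llama_decode_internal(lctx, batch_all, server_mode)
    //     -> llama_send_meta(*lctx.send_socket, &meta, /*align_seq_ids=*/server_mode)
    //        -> all_seq_ids[i] = seq_id[i][0] - 1 when the head is a server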