Mirror of https://github.com/Lizonghang/prima.cpp.git (synced 2025-09-06)
fix seq_id mismatch between head and worker devices
commit 3e6d831930 (parent fb9b1f2b00)
3 changed files with 33 additions and 33 deletions
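Note: the mismatch being fixed is an off-by-one between the sequence ids used on the head device's server path and the ids the worker devices expect. The commit threads a new server_mode flag from llama_decode down to llama_send_meta, which, when set, subtracts one from each outgoing seq_id; it also replaces several buffers sized by the context length (n_ctx) with buffers sized by the batch's token count (n_tokens), so sender and receiver agree on message sizes. A minimal sketch of the rebasing idea (illustrative, not the commit's code):

#include <cstdint>
#include <vector>

using llama_seq_id = int32_t; // matches the llama.cpp typedef

// Hypothetical helper: one seq_id per token, rebased before transmission so
// head id N maps to worker id N-1 when alignment is requested.
static std::vector<llama_seq_id> rebase_seq_ids(
        const std::vector<llama_seq_id> & ids, bool align_seq_ids) {
    const llama_seq_id offset = align_seq_ids ? 1 : 0;
    std::vector<llama_seq_id> out;
    out.reserve(ids.size());
    for (const llama_seq_id id : ids) {
        out.push_back(id - offset);
    }
    return out;
}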
@@ -1059,13 +1059,9 @@ struct server_context {
     }

     void kv_cache_clear() {
-        SRV_DBG("%s", "clearing KV cache\n");
-
-        // clear the entire KV cache
+        SRV_DBG("%s", "clearing all KV cache\n");
         llama_kv_cache_clear(ctx);
-
         llama_send_kv_cache_clear(ctx);
-
         clean_kv_cache = false;
     }

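The clear path here is head-driven: the head drops its own KV cache, then llama_send_kv_cache_clear (a prima.cpp addition, per this diff) tells the worker devices to do the same, so no device keeps stale sequence state. A sketch of the calling pattern, assuming this fork's llama.h declares both functions:

#include "llama.h" // prima.cpp fork: also declares llama_send_kv_cache_clear

static void clear_kv_everywhere(llama_context * ctx) {
    llama_kv_cache_clear(ctx);      // local KV cache on this device
    llama_send_kv_cache_clear(ctx); // broadcast the clear to the workers
}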
@@ -1090,7 +1086,7 @@ struct server_context {
             llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
         }

-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode(ctx, batch, true) != 0) {
            SRV_ERR("%s", "llama_decode() failed\n");
            return;
        }
@@ -2311,7 +2307,7 @@ struct server_context {
             0, 0, 0, // unused
         };

-        const int ret = llama_decode(ctx, batch_view);
+        const int ret = llama_decode(ctx, batch_view, true);
         metrics.on_decoded(slots);

         if (ret != 0) {
@@ -957,7 +957,8 @@ extern "C" {
     // < 0 - error
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
-            struct llama_batch batch);
+            struct llama_batch batch,
+            bool server_mode = false);

     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
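Because the extra parameter defaults to false, existing two-argument call sites keep compiling unchanged; only the server paths shown above opt in. A sketch of both call forms (the wrapper names are hypothetical, assuming this fork's llama.h):

#include "llama.h"

static int32_t decode_as_head(llama_context * ctx, llama_batch batch) {
    return llama_decode(ctx, batch, true);  // align seq_ids for the workers
}

static int32_t decode_plain(llama_context * ctx, llama_batch batch) {
    return llama_decode(ctx, batch);        // server_mode defaults to false
}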
@@ -17849,7 +17849,7 @@ struct sync_meta {
     int div_factor = 1;
 };

-static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
+static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta, bool align_seq_ids = false) {
     GGML_ASSERT(meta != nullptr);
     try {
         std::vector<zmq::message_t> send_msgs;
@@ -17864,19 +17864,20 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         }

         if (meta->n_seq_id != nullptr) {
-            GGML_ASSERT(meta->n_ctx > 0);
+            GGML_ASSERT(meta->n_tokens > 0);
             send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
-            send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
+            send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));

             // here we assume only a single seq_id per token is needed
             // pack all single seq_id values into a contiguous array
-            llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_ctx * sizeof(llama_seq_id));
-            for (uint32_t i = 0; i < meta->n_ctx; ++i) {
-                all_seq_ids[i] = meta->seq_id[i][0];
+            llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_tokens * sizeof(llama_seq_id));
+            int seq_id_offset = align_seq_ids ? 1 : 0;
+            for (int32_t i = 0; i < meta->n_tokens; ++i) {
+                all_seq_ids[i] = meta->seq_id[i][0] - seq_id_offset;
             }

             send_msgs.emplace_back("seq_id", strlen("seq_id"));
-            send_msgs.emplace_back(all_seq_ids, meta->n_ctx * sizeof(llama_seq_id));
+            send_msgs.emplace_back(all_seq_ids, meta->n_tokens * sizeof(llama_seq_id));
             free(all_seq_ids);
         }

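A standalone version of the packing step, for clarity (simplified; the function name is illustrative): one seq_id per token is flattened into a contiguous scratch buffer, rebased when align_seq_ids is set, and handed to ZeroMQ. The zmq::message_t(const void *, size_t) constructor copies the bytes, which is why the buffer can be freed immediately, as in the hunk above:

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <zmq.hpp>

using llama_seq_id = int32_t;

static void queue_seq_ids(std::vector<zmq::message_t> & msgs,
                          llama_seq_id * const * seq_id, int32_t n_tokens,
                          bool align_seq_ids) {
    llama_seq_id * buf = (llama_seq_id *) malloc(n_tokens * sizeof(llama_seq_id));
    const int offset = align_seq_ids ? 1 : 0;
    for (int32_t i = 0; i < n_tokens; ++i) {
        buf[i] = seq_id[i][0] - offset; // single seq_id per token assumed
    }
    msgs.emplace_back("seq_id", strlen("seq_id"));
    msgs.emplace_back(buf, n_tokens * sizeof(llama_seq_id)); // copies the bytes
    free(buf);
}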
@@ -17966,18 +17967,18 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         }

         if (key == "n_seq_id") {
-            GGML_ASSERT(meta->n_ctx > 0);
-            GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
-            meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
-            std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
+            GGML_ASSERT(meta->n_tokens > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int32_t));
+            meta->n_seq_id = (int32_t *) malloc(meta->n_tokens * sizeof(int32_t));
+            std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
         }

         if (key == "seq_id") {
-            GGML_ASSERT(meta->n_ctx > 0);
-            GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(llama_seq_id));
+            GGML_ASSERT(meta->n_tokens > 0);
+            GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(llama_seq_id));
             const llama_seq_id * all_seq_ids = (llama_seq_id *) data_msg.data();
-            meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
-            for (uint32_t i = 0; i < meta->n_ctx; ++i) {
+            meta->seq_id = (llama_seq_id **) malloc(meta->n_tokens * sizeof(llama_seq_id *));
+            for (int32_t i = 0; i < meta->n_tokens; ++i) {
                 meta->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
                 meta->seq_id[i][0] = all_seq_ids[i];
             }
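The receive side now asserts that the payload is exactly n_tokens * sizeof(llama_seq_id) bytes, which only holds because the sender above switched to n_tokens as well; with the old n_ctx sizing the two ends disagreed whenever the batch was smaller than the context. A reduced sketch of the unpacking invariant (illustrative names, not the fork's API):

#include <cassert>
#include <cstdint>
#include <cstdlib>

using llama_seq_id = int32_t;

// Rebuild the per-token seq_id arrays from one flat payload.
static llama_seq_id ** unpack_seq_ids(const void * data, size_t size,
                                      int32_t n_tokens) {
    assert(size == (size_t) n_tokens * sizeof(llama_seq_id)); // ends must agree
    const llama_seq_id * flat = (const llama_seq_id *) data;
    llama_seq_id ** out = (llama_seq_id **) malloc(n_tokens * sizeof(llama_seq_id *));
    for (int32_t i = 0; i < n_tokens; ++i) {
        out[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
        out[i][0] = flat[i]; // one seq_id per token, mirroring the sender
    }
    return out;
}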
@@ -18203,7 +18204,8 @@ static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool f
 //
 static int llama_decode_internal(
          llama_context & lctx,
-         llama_batch batch_all) { // TODO: rename back to batch
+         llama_batch batch_all,
+         bool server_mode) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -18275,16 +18277,16 @@ static int llama_decode_internal(
     if (meta.n_tokens > 0) {
         batch_all.n_tokens = meta.n_tokens;
         if (meta.pos != nullptr) {
-            batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
-            std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
+            batch_all.pos = (llama_pos *) malloc(meta.n_ctx * sizeof(llama_pos));
+            std::memcpy(batch_all.pos, meta.pos, meta.n_ctx * sizeof(llama_pos));
         }
         if (meta.n_seq_id != nullptr) {
-            batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t));
-            std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t));
+            batch_all.n_seq_id = (int32_t *) malloc(meta.n_tokens * sizeof(int32_t));
+            std::memcpy(batch_all.n_seq_id, meta.n_seq_id, meta.n_tokens * sizeof(int32_t));
         }
         if (meta.seq_id != nullptr) {
-            batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
-            for (size_t i = 0; i < cparams.n_ctx; ++i) {
+            batch_all.seq_id = (llama_seq_id **) malloc(meta.n_tokens * sizeof(llama_seq_id *));
+            for (int32_t i = 0; i < meta.n_tokens; ++i) {
                 batch_all.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
                 batch_all.seq_id[i][0] = meta.seq_id[i][0];
             }
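The same sizing fix applies to the worker's batch reconstruction: n_seq_id and seq_id are per-token arrays, so allocating and copying cparams.n_ctx entries read past what the head actually sent; meta.n_tokens is the correct bound (pos, by contrast, is still carried at meta.n_ctx length in this hunk). A reduced sketch of the per-token rebuild, lifted out of its real context with a stand-in struct:

#include <cstdint>
#include <cstdlib>
#include <cstring>

using llama_seq_id = int32_t;

struct mini_batch { // stand-in for the llama_batch fields used here
    int32_t         n_tokens;
    int32_t *       n_seq_id;
    llama_seq_id ** seq_id;
};

static void rebuild_batch(mini_batch & b, const int32_t * recv_n_seq_id,
                          llama_seq_id * const * recv_seq_id, int32_t n_tokens) {
    b.n_tokens = n_tokens;
    b.n_seq_id = (int32_t *) malloc(n_tokens * sizeof(int32_t));
    std::memcpy(b.n_seq_id, recv_n_seq_id, n_tokens * sizeof(int32_t)); // not n_ctx
    b.seq_id = (llama_seq_id **) malloc(n_tokens * sizeof(llama_seq_id *));
    for (int32_t i = 0; i < n_tokens; ++i) {
        b.seq_id[i]    = (llama_seq_id *) malloc(sizeof(llama_seq_id));
        b.seq_id[i][0] = recv_seq_id[i][0];
    }
}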
@@ -18346,7 +18348,7 @@ static int llama_decode_internal(
         meta.logits    = batch_all.logits;
         meta.all_pos_0 = batch_all.all_pos_0;
         meta.all_pos_1 = batch_all.all_pos_1;
-        llama_send_meta(*lctx.send_socket, &meta);
+        llama_send_meta(*lctx.send_socket, &meta, server_mode);
     }

     lctx.sbatch.from_batch(batch_all, n_embd,
@@ -23484,8 +23486,9 @@ int32_t llama_encode(

 int32_t llama_decode(
         struct llama_context * ctx,
-        struct llama_batch batch) {
-    return llama_decode_internal(*ctx, batch);
+        struct llama_batch batch,
+        bool server_mode) {
+    return llama_decode_internal(*ctx, batch, server_mode);
 }

 void llama_synchronize(struct llama_context * ctx) {
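Taken together, the flag flows through a single call chain: the server's llama_decode(ctx, batch, true) reaches llama_decode_internal, which passes server_mode to llama_send_meta, which subtracts the one-id offset before the metadata leaves the head. A hypothetical head-side caller showing the intended use:

#include "llama.h"

// Hypothetical head-device decode step: only the head passes true; workers
// and plain single-device clients keep the default false.
static int32_t run_head_decode(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch, /*server_mode=*/true);
    if (ret != 0) {
        return ret; // decode failed; callers log and bail as in the server code
    }
    return 0;
}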