fix seq_id mismatch between head and worker devices

This commit is contained in:
Li, Zonghang 2025-06-11 17:10:21 +04:00
parent fb9b1f2b00
commit 3e6d831930
3 changed files with 33 additions and 33 deletions

View file

@ -1059,13 +1059,9 @@ struct server_context {
}
void kv_cache_clear() {
SRV_DBG("%s", "clearing KV cache\n");
// clear the entire KV cache
SRV_DBG("%s", "clearing all KV cache\n");
llama_kv_cache_clear(ctx);
llama_send_kv_cache_clear(ctx);
clean_kv_cache = false;
}
@ -1090,7 +1086,7 @@ struct server_context {
llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
}
if (llama_decode(ctx, batch) != 0) {
if (llama_decode(ctx, batch, true) != 0) {
SRV_ERR("%s", "llama_decode() failed\n");
return;
}
@ -2311,7 +2307,7 @@ struct server_context {
0, 0, 0, // unused
};
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode(ctx, batch_view, true);
metrics.on_decoded(slots);
if (ret != 0) {

View file

@ -957,7 +957,8 @@ extern "C" {
// < 0 - error
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
struct llama_batch batch,
bool server_mode = false);
// Set the number of threads used for decoding
// n_threads is the number of threads used for generation (single token)

View file

@ -17849,7 +17849,7 @@ struct sync_meta {
int div_factor = 1;
};
static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta, bool align_seq_ids = false) {
GGML_ASSERT(meta != nullptr);
try {
std::vector<zmq::message_t> send_msgs;
@ -17864,19 +17864,20 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
}
if (meta->n_seq_id != nullptr) {
GGML_ASSERT(meta->n_ctx > 0);
GGML_ASSERT(meta->n_tokens > 0);
send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));
// here we assume only a single seq_id per token is needed
// pack all single seq_id values into a contiguous array
llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_ctx * sizeof(llama_seq_id));
for (uint32_t i = 0; i < meta->n_ctx; ++i) {
all_seq_ids[i] = meta->seq_id[i][0];
llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_tokens * sizeof(llama_seq_id));
int seq_id_offset = align_seq_ids ? 1 : 0;
for (int32_t i = 0; i < meta->n_tokens; ++i) {
all_seq_ids[i] = meta->seq_id[i][0] - seq_id_offset;
}
send_msgs.emplace_back("seq_id", strlen("seq_id"));
send_msgs.emplace_back(all_seq_ids, meta->n_ctx * sizeof(llama_seq_id));
send_msgs.emplace_back(all_seq_ids, meta->n_tokens * sizeof(llama_seq_id));
free(all_seq_ids);
}
@ -17966,18 +17967,18 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
}
if (key == "n_seq_id") {
GGML_ASSERT(meta->n_ctx > 0);
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
GGML_ASSERT(meta->n_tokens > 0);
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int32_t));
meta->n_seq_id = (int32_t *) malloc(meta->n_tokens * sizeof(int32_t));
std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
}
if (key == "seq_id") {
GGML_ASSERT(meta->n_ctx > 0);
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(llama_seq_id));
GGML_ASSERT(meta->n_tokens > 0);
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(llama_seq_id));
const llama_seq_id * all_seq_ids = (llama_seq_id *) data_msg.data();
meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
for (uint32_t i = 0; i < meta->n_ctx; ++i) {
meta->seq_id = (llama_seq_id **) malloc(meta->n_tokens * sizeof(llama_seq_id *));
for (int32_t i = 0; i < meta->n_tokens; ++i) {
meta->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
meta->seq_id[i][0] = all_seq_ids[i];
}
@ -18203,7 +18204,8 @@ static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool f
//
static int llama_decode_internal(
llama_context & lctx,
llama_batch batch_all) { // TODO: rename back to batch
llama_batch batch_all,
bool server_mode) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
@ -18275,16 +18277,16 @@ static int llama_decode_internal(
if (meta.n_tokens > 0) {
batch_all.n_tokens = meta.n_tokens;
if (meta.pos != nullptr) {
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
batch_all.pos = (llama_pos *) malloc(meta.n_ctx * sizeof(llama_pos));
std::memcpy(batch_all.pos, meta.pos, meta.n_ctx * sizeof(llama_pos));
}
if (meta.n_seq_id != nullptr) {
batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t));
std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t));
batch_all.n_seq_id = (int32_t *) malloc(meta.n_tokens * sizeof(int32_t));
std::memcpy(batch_all.n_seq_id, meta.n_seq_id, meta.n_tokens * sizeof(int32_t));
}
if (meta.seq_id != nullptr) {
batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
for (size_t i = 0; i < cparams.n_ctx; ++i) {
batch_all.seq_id = (llama_seq_id **) malloc(meta.n_tokens * sizeof(llama_seq_id *));
for (int32_t i = 0; i < meta.n_tokens; ++i) {
batch_all.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
batch_all.seq_id[i][0] = meta.seq_id[i][0];
}
@ -18346,7 +18348,7 @@ static int llama_decode_internal(
meta.logits = batch_all.logits;
meta.all_pos_0 = batch_all.all_pos_0;
meta.all_pos_1 = batch_all.all_pos_1;
llama_send_meta(*lctx.send_socket, &meta);
llama_send_meta(*lctx.send_socket, &meta, server_mode);
}
lctx.sbatch.from_batch(batch_all, n_embd,
@ -23484,8 +23486,9 @@ int32_t llama_encode(
int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
return llama_decode_internal(*ctx, batch);
struct llama_batch batch,
bool server_mode) {
return llama_decode_internal(*ctx, batch, server_mode);
}
void llama_synchronize(struct llama_context * ctx) {