Merge pull request #24 from Lizonghang/lizh_dev

Fix batch decoding and dynamic batching.
This commit is contained in:
Zonghang Li 2025-06-07 01:02:10 +04:00 committed by GitHub
commit c8af1be27e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 125 additions and 54 deletions

View file

@ -1424,13 +1424,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.defrag_thold = std::stof(value);
}
).set_env("LLAMA_ARG_DEFRAG_THOLD"));
// add_opt(llama_arg(
// {"-np", "--parallel"}, "N",
// format("number of parallel sequences to decode (default: %d)", params.n_parallel),
// [](gpt_params & params, int value) {
// params.n_parallel = value;
// }
// ).set_env("LLAMA_ARG_N_PARALLEL"));
add_opt(llama_arg(
{"-np", "--parallel"}, "N",
format("number of parallel sequences to decode (default: %d)", params.n_parallel),
[](gpt_params & params, int value) {
params.n_parallel = value;
}
).set_env("LLAMA_ARG_N_PARALLEL"));
add_opt(llama_arg(
{"-ns", "--sequences"}, "N",
format("number of sequences to decode (default: %d)", params.n_sequences),

View file

@ -2782,7 +2782,6 @@ struct llama_layer {
// but has more metadata about sequences
struct llama_ubatch {
bool equal_seqs;
// TODO: whole_seqs for embeddings?
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
uint32_t n_seq_tokens; // tokens per sequence
@ -2796,6 +2795,9 @@ struct llama_ubatch {
int32_t * n_seq_id; // [n_seqs]
llama_seq_id ** seq_id; // [n_seqs]
int8_t * output; // [n_tokens]
bool activate_input;
bool activate_output;
};
struct llama_kv_cell {
@ -3040,7 +3042,7 @@ struct llama_sbatch {
ubatch_token.resize(!has_embd ? n_ubatch : 0);
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
ubatch_backend_embd.resize(n_embd * n_tokens);
ubatch_out_embd.resize(n_embd);
ubatch_out_embd.resize(n_embd * n_tokens);
ubatch_pos.resize(n_ubatch);
ubatch_n_seq_id.resize(n_ubatch);
ubatch_seq_id.resize(n_ubatch);
@ -3058,6 +3060,8 @@ struct llama_sbatch {
/*n_seq_id =*/ ubatch_n_seq_id.data(),
/*seq_id =*/ ubatch_seq_id.data(),
/*output =*/ ubatch_output.data(),
/*activate_input =*/ true,
/*activate_output =*/ false,
};
return ubatch;
}
@ -11104,7 +11108,6 @@ struct llm_build_context {
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
n_tokens = n_outputs;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@ -16978,7 +16981,7 @@ static std::vector<struct ggml_cgraph *> llama_build_graph(
llm.init();
GGML_ASSERT((model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_QWEN2) && "this model is currently not supported");
GGML_ASSERT((model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_QWEN2) && "this model is currently not supported.\n");
switch (model.arch) {
case LLM_ARCH_LLAMA:
@ -17261,31 +17264,32 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
if (batch.token) {
if (batch.activate_input) {
const int64_t n_tokens = batch.n_tokens;
ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
}
if (batch.token) {
const int64_t size_ = n_tokens * ggml_element_size(lctx.inp_tokens);
ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, size_);
}
if (batch.embd) {
const int64_t n_embd = hparams.n_embd;
const int64_t n_tokens = batch.n_tokens;
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
}
if (batch.backend_embd && lctx.backend_embd && lctx.backend_embd->data != nullptr) {
if (batch.embd) {
const int64_t n_embd = hparams.n_embd;
const int64_t size_ = n_tokens * n_embd * ggml_element_size(lctx.inp_embd);
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, size_);
}
} else if (batch.activate_output) {
if (batch.out_embd && lctx.out_embd) {
const int64_t n_embd = lctx.out_embd->ne[0];
const int64_t n_output = lctx.out_embd->ne[1];
const int64_t size_ = n_output * n_embd * ggml_element_size(lctx.out_embd);
ggml_backend_tensor_set(lctx.out_embd, batch.out_embd, 0, size_);
}
} else {
GGML_ASSERT(batch.backend_embd && lctx.backend_embd && lctx.backend_embd->data != nullptr);
const int64_t n_embd = lctx.backend_embd->ne[0];
const int64_t n_tokens = lctx.backend_embd->ne[1];
ggml_backend_tensor_set(lctx.backend_embd, batch.backend_embd, 0, n_tokens*n_embd*ggml_element_size(lctx.backend_embd));
}
if (batch.out_embd && lctx.out_embd) {
const int64_t n_embd = lctx.out_embd->ne[0];
const int64_t n_output = lctx.out_embd->ne[1];
ggml_backend_tensor_set(lctx.out_embd, batch.out_embd, 0, n_output*n_embd*ggml_element_size(lctx.out_embd));
const int64_t size_ = n_tokens * n_embd * ggml_element_size(lctx.backend_embd);
ggml_backend_tensor_set(lctx.backend_embd, batch.backend_embd, 0, size_);
}
if (batch.pos && lctx.inp_pos) {
@ -17801,12 +17805,16 @@ struct input_tensors {
};
struct sync_meta {
int32_t n_tokens = 0;
llama_pos * pos = nullptr;
int32_t n_tokens = 0;
llama_pos * pos = nullptr;
int32_t * n_seq_id = nullptr;
llama_seq_id ** seq_id = nullptr;
int8_t * logits = nullptr;
llama_pos all_pos_0;
llama_pos all_pos_1;
uint32_t n_ctx = 0;
uint32_t n_ctx = 0;
// signal to clear the kv cache
bool clear_kv_cache = false;
@ -17852,6 +17860,29 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
}
if (meta->n_seq_id != nullptr) {
GGML_ASSERT(meta->n_ctx > 0);
send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
// here we assume only a single seq_id per token is needed
// pack all single seq_id values into a contiguous array
llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_ctx * sizeof(llama_seq_id));
for (uint32_t i = 0; i < meta->n_ctx; ++i) {
all_seq_ids[i] = meta->seq_id[i][0];
}
send_msgs.emplace_back("seq_id", strlen("seq_id"));
send_msgs.emplace_back(all_seq_ids, meta->n_ctx * sizeof(llama_seq_id));
free(all_seq_ids);
}
if (meta->logits != nullptr) {
GGML_ASSERT(meta->n_tokens > 0);
send_msgs.emplace_back("logits", strlen("logits"));
send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t));
}
send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
@ -17931,6 +17962,31 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
}
if (key == "n_seq_id") {
GGML_ASSERT(meta->n_ctx > 0);
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
}
if (key == "seq_id") {
GGML_ASSERT(meta->n_ctx > 0);
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(llama_seq_id));
const llama_seq_id * all_seq_ids = (llama_seq_id *) data_msg.data();
meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
for (uint32_t i = 0; i < meta->n_ctx; ++i) {
meta->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
meta->seq_id[i][0] = all_seq_ids[i];
}
}
if (key == "logits") {
GGML_ASSERT(meta->n_tokens > 0);
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
}
if (key == "all_pos_0") {
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
@ -17971,6 +18027,7 @@ static void llama_recv_tensors(zmq::socket_t & socket, struct llama_ubatch * uba
std::vector<zmq::message_t> recv_msgs;
if (!zmq::recv_multipart(socket, std::back_inserter(recv_msgs))) {
LLAMA_LOG_INFO("Failed to receive tensor data.\n");
return;
}
for (size_t i = 0; i < recv_msgs.size(); i += 3) {
@ -18218,6 +18275,21 @@ static int llama_decode_internal(
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
}
if (meta.n_seq_id != nullptr) {
batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t));
std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t));
}
if (meta.seq_id != nullptr) {
batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
for (size_t i = 0; i < cparams.n_ctx; ++i) {
batch_all.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
batch_all.seq_id[i][0] = meta.seq_id[i][0];
}
}
if (meta.logits != nullptr) {
batch_all.logits = (int8_t *) malloc(meta.n_tokens * sizeof(int8_t));
std::memcpy(batch_all.logits, meta.logits, meta.n_tokens * sizeof(int8_t));
}
batch_all.all_pos_0 = meta.all_pos_0;
batch_all.all_pos_1 = meta.all_pos_1;
}
@ -18266,6 +18338,9 @@ static int llama_decode_internal(
if (!is_last_dev) {
meta.n_tokens = batch_all.n_tokens;
meta.pos = batch_all.pos;
meta.n_seq_id = batch_all.n_seq_id;
meta.seq_id = batch_all.seq_id;
meta.logits = batch_all.logits;
meta.all_pos_0 = batch_all.all_pos_0;
meta.all_pos_1 = batch_all.all_pos_1;
llama_send_meta(*lctx.send_socket, &meta);
@ -18281,8 +18356,7 @@ static int llama_decode_internal(
return -2;
};
{ // assume there is only one batch
// while (lctx.sbatch.n_tokens > 0) { // handle multiple batches
while (lctx.sbatch.n_tokens > 0) { // handle multiple batches
llama_ubatch ubatch;
if (kv_self.recurrent) {
if (embd_pooled) {
@ -18300,26 +18374,19 @@ static int llama_decode_internal(
// count the outputs in this u_batch
int32_t n_outputs_new = 0;
if (my_rank == 0) {
if (n_outputs == n_tokens_all) {
n_outputs_new = n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}
if (n_outputs == n_tokens_all) {
n_outputs_new = n_tokens;
} else {
n_outputs_new += 1;
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}
// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
GGML_ASSERT(n_threads > 0);
// non-causal masks do not use the KV cache
@ -18394,11 +18461,11 @@ static int llama_decode_internal(
GGML_ASSERT(my_rank == 0 || n_world > 1);
for (size_t i = 0; i < (size_t)gf.size(); ++i) {
const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
sub_gf = gf[i];
// receive data from other nodes
if (n_world > 1 && !(my_rank == 0 && i == 0) && !(my_rank == 0 && is_last_l)) {
const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
llama_recv_tensors(*lctx.recv_socket, &ubatch, is_out_embd);
}
@ -18407,6 +18474,10 @@ static int llama_decode_internal(
ggml_backend_sched_synchronize(lctx.sched[i - 1]);
}
ubatch.activate_input = (my_rank == 0 && i == 0);
ubatch.activate_output = (my_rank == 0 && is_out_embd);
GGML_ASSERT(!(ubatch.activate_input && ubatch.activate_output));
llama_set_inputs(lctx, ubatch);
{ // compute graph
@ -18442,13 +18513,13 @@ static int llama_decode_internal(
GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));
GGML_ASSERT(backend != nullptr);
ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);
ggml_backend_sched_synchronize(lctx.sched[i]);
// send the result to the next node or the master
if (!(n_world == 1 || (my_rank == 0 && is_last_l))) {
struct input_tensors tensors = {sub_gf_out, lctx.inp_pos};
const bool is_to_master = my_rank != 0 && is_last_l;
zmq::socket_t * s = is_to_master ? lctx.master_socket : lctx.send_socket;
ggml_backend_sched_synchronize(lctx.sched[i]);
llama_send_tensors(*s, &ubatch, &tensors);
}
@ -19038,7 +19109,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, true, false};
std::vector<ggml_cgraph *> gf = llama_build_graph(lctx, ubatch, true);
GGML_ASSERT(lctx.sched.size() == gf.size());
@ -21115,7 +21186,7 @@ void * llama_context_setup_backend(
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, true, false};
std::vector<ggml_cgraph *> gf = llama_build_graph(*ctx, ubatch, true);
GGML_ASSERT(gf.size() <= MAX_SCHEDULERS && "Number of subgraphs exceeds the maximum number of schedulers\n");