mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-05 22:19:03 +00:00
Merge pull request #24 from Lizonghang/lizh_dev
Fix batch decoding and dynamic batching.
This commit is contained in:
commit
c8af1be27e
2 changed files with 125 additions and 54 deletions
|
@ -1424,13 +1424,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
|||
params.defrag_thold = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_DEFRAG_THOLD"));
|
||||
// add_opt(llama_arg(
|
||||
// {"-np", "--parallel"}, "N",
|
||||
// format("number of parallel sequences to decode (default: %d)", params.n_parallel),
|
||||
// [](gpt_params & params, int value) {
|
||||
// params.n_parallel = value;
|
||||
// }
|
||||
// ).set_env("LLAMA_ARG_N_PARALLEL"));
|
||||
add_opt(llama_arg(
|
||||
{"-np", "--parallel"}, "N",
|
||||
format("number of parallel sequences to decode (default: %d)", params.n_parallel),
|
||||
[](gpt_params & params, int value) {
|
||||
params.n_parallel = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_N_PARALLEL"));
|
||||
add_opt(llama_arg(
|
||||
{"-ns", "--sequences"}, "N",
|
||||
format("number of sequences to decode (default: %d)", params.n_sequences),
|
||||
|
|
165
src/llama.cpp
165
src/llama.cpp
|
@ -2782,7 +2782,6 @@ struct llama_layer {
|
|||
// but has more metadata about sequences
|
||||
struct llama_ubatch {
|
||||
bool equal_seqs;
|
||||
// TODO: whole_seqs for embeddings?
|
||||
|
||||
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
||||
uint32_t n_seq_tokens; // tokens per sequence
|
||||
|
@ -2796,6 +2795,9 @@ struct llama_ubatch {
|
|||
int32_t * n_seq_id; // [n_seqs]
|
||||
llama_seq_id ** seq_id; // [n_seqs]
|
||||
int8_t * output; // [n_tokens]
|
||||
|
||||
bool activate_input;
|
||||
bool activate_output;
|
||||
};
|
||||
|
||||
struct llama_kv_cell {
|
||||
|
@ -3040,7 +3042,7 @@ struct llama_sbatch {
|
|||
ubatch_token.resize(!has_embd ? n_ubatch : 0);
|
||||
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
||||
ubatch_backend_embd.resize(n_embd * n_tokens);
|
||||
ubatch_out_embd.resize(n_embd);
|
||||
ubatch_out_embd.resize(n_embd * n_tokens);
|
||||
ubatch_pos.resize(n_ubatch);
|
||||
ubatch_n_seq_id.resize(n_ubatch);
|
||||
ubatch_seq_id.resize(n_ubatch);
|
||||
|
@ -3058,6 +3060,8 @@ struct llama_sbatch {
|
|||
/*n_seq_id =*/ ubatch_n_seq_id.data(),
|
||||
/*seq_id =*/ ubatch_seq_id.data(),
|
||||
/*output =*/ ubatch_output.data(),
|
||||
/*activate_input =*/ true,
|
||||
/*activate_output =*/ false,
|
||||
};
|
||||
return ubatch;
|
||||
}
|
||||
|
@ -11104,7 +11108,6 @@ struct llm_build_context {
|
|||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
n_tokens = n_outputs;
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
@ -16978,7 +16981,7 @@ static std::vector<struct ggml_cgraph *> llama_build_graph(
|
|||
|
||||
llm.init();
|
||||
|
||||
GGML_ASSERT((model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_QWEN2) && "this model is currently not supported");
|
||||
GGML_ASSERT((model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_QWEN2) && "this model is currently not supported.\n");
|
||||
|
||||
switch (model.arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
|
@ -17261,31 +17264,32 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
|||
const auto & cparams = lctx.cparams;
|
||||
const auto & kv_self = lctx.kv_self;
|
||||
|
||||
if (batch.token) {
|
||||
if (batch.activate_input) {
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
|
||||
}
|
||||
if (batch.token) {
|
||||
const int64_t size_ = n_tokens * ggml_element_size(lctx.inp_tokens);
|
||||
ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, size_);
|
||||
}
|
||||
|
||||
if (batch.embd) {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
|
||||
}
|
||||
|
||||
if (batch.backend_embd && lctx.backend_embd && lctx.backend_embd->data != nullptr) {
|
||||
if (batch.embd) {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t size_ = n_tokens * n_embd * ggml_element_size(lctx.inp_embd);
|
||||
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, size_);
|
||||
}
|
||||
} else if (batch.activate_output) {
|
||||
if (batch.out_embd && lctx.out_embd) {
|
||||
const int64_t n_embd = lctx.out_embd->ne[0];
|
||||
const int64_t n_output = lctx.out_embd->ne[1];
|
||||
const int64_t size_ = n_output * n_embd * ggml_element_size(lctx.out_embd);
|
||||
ggml_backend_tensor_set(lctx.out_embd, batch.out_embd, 0, size_);
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(batch.backend_embd && lctx.backend_embd && lctx.backend_embd->data != nullptr);
|
||||
const int64_t n_embd = lctx.backend_embd->ne[0];
|
||||
const int64_t n_tokens = lctx.backend_embd->ne[1];
|
||||
|
||||
ggml_backend_tensor_set(lctx.backend_embd, batch.backend_embd, 0, n_tokens*n_embd*ggml_element_size(lctx.backend_embd));
|
||||
}
|
||||
|
||||
if (batch.out_embd && lctx.out_embd) {
|
||||
const int64_t n_embd = lctx.out_embd->ne[0];
|
||||
const int64_t n_output = lctx.out_embd->ne[1];
|
||||
|
||||
ggml_backend_tensor_set(lctx.out_embd, batch.out_embd, 0, n_output*n_embd*ggml_element_size(lctx.out_embd));
|
||||
const int64_t size_ = n_tokens * n_embd * ggml_element_size(lctx.backend_embd);
|
||||
ggml_backend_tensor_set(lctx.backend_embd, batch.backend_embd, 0, size_);
|
||||
}
|
||||
|
||||
if (batch.pos && lctx.inp_pos) {
|
||||
|
@ -17801,12 +17805,16 @@ struct input_tensors {
|
|||
};
|
||||
|
||||
struct sync_meta {
|
||||
int32_t n_tokens = 0;
|
||||
llama_pos * pos = nullptr;
|
||||
int32_t n_tokens = 0;
|
||||
llama_pos * pos = nullptr;
|
||||
int32_t * n_seq_id = nullptr;
|
||||
llama_seq_id ** seq_id = nullptr;
|
||||
int8_t * logits = nullptr;
|
||||
|
||||
llama_pos all_pos_0;
|
||||
llama_pos all_pos_1;
|
||||
uint32_t n_ctx = 0;
|
||||
|
||||
uint32_t n_ctx = 0;
|
||||
|
||||
// signal to clear the kv cache
|
||||
bool clear_kv_cache = false;
|
||||
|
||||
|
@ -17852,6 +17860,29 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
|
||||
if (meta->n_seq_id != nullptr) {
|
||||
GGML_ASSERT(meta->n_ctx > 0);
|
||||
send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
|
||||
send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
|
||||
|
||||
// here we assume only a single seq_id per token is needed
|
||||
// pack all single seq_id values into a contiguous array
|
||||
llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_ctx * sizeof(llama_seq_id));
|
||||
for (uint32_t i = 0; i < meta->n_ctx; ++i) {
|
||||
all_seq_ids[i] = meta->seq_id[i][0];
|
||||
}
|
||||
|
||||
send_msgs.emplace_back("seq_id", strlen("seq_id"));
|
||||
send_msgs.emplace_back(all_seq_ids, meta->n_ctx * sizeof(llama_seq_id));
|
||||
free(all_seq_ids);
|
||||
}
|
||||
|
||||
if (meta->logits != nullptr) {
|
||||
GGML_ASSERT(meta->n_tokens > 0);
|
||||
send_msgs.emplace_back("logits", strlen("logits"));
|
||||
send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t));
|
||||
}
|
||||
|
||||
send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
|
||||
send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
|
||||
|
||||
|
@ -17931,6 +17962,31 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
|
||||
if (key == "n_seq_id") {
|
||||
GGML_ASSERT(meta->n_ctx > 0);
|
||||
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
|
||||
meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
|
||||
std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
|
||||
}
|
||||
|
||||
if (key == "seq_id") {
|
||||
GGML_ASSERT(meta->n_ctx > 0);
|
||||
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(llama_seq_id));
|
||||
const llama_seq_id * all_seq_ids = (llama_seq_id *) data_msg.data();
|
||||
meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
|
||||
for (uint32_t i = 0; i < meta->n_ctx; ++i) {
|
||||
meta->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
|
||||
meta->seq_id[i][0] = all_seq_ids[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (key == "logits") {
|
||||
GGML_ASSERT(meta->n_tokens > 0);
|
||||
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
|
||||
meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
|
||||
std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
|
||||
}
|
||||
|
||||
if (key == "all_pos_0") {
|
||||
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
|
||||
std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
|
||||
|
@ -17971,6 +18027,7 @@ static void llama_recv_tensors(zmq::socket_t & socket, struct llama_ubatch * uba
|
|||
std::vector<zmq::message_t> recv_msgs;
|
||||
if (!zmq::recv_multipart(socket, std::back_inserter(recv_msgs))) {
|
||||
LLAMA_LOG_INFO("Failed to receive tensor data.\n");
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < recv_msgs.size(); i += 3) {
|
||||
|
@ -18218,6 +18275,21 @@ static int llama_decode_internal(
|
|||
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
|
||||
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
if (meta.n_seq_id != nullptr) {
|
||||
batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t));
|
||||
std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t));
|
||||
}
|
||||
if (meta.seq_id != nullptr) {
|
||||
batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
|
||||
for (size_t i = 0; i < cparams.n_ctx; ++i) {
|
||||
batch_all.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
|
||||
batch_all.seq_id[i][0] = meta.seq_id[i][0];
|
||||
}
|
||||
}
|
||||
if (meta.logits != nullptr) {
|
||||
batch_all.logits = (int8_t *) malloc(meta.n_tokens * sizeof(int8_t));
|
||||
std::memcpy(batch_all.logits, meta.logits, meta.n_tokens * sizeof(int8_t));
|
||||
}
|
||||
batch_all.all_pos_0 = meta.all_pos_0;
|
||||
batch_all.all_pos_1 = meta.all_pos_1;
|
||||
}
|
||||
|
@ -18266,6 +18338,9 @@ static int llama_decode_internal(
|
|||
if (!is_last_dev) {
|
||||
meta.n_tokens = batch_all.n_tokens;
|
||||
meta.pos = batch_all.pos;
|
||||
meta.n_seq_id = batch_all.n_seq_id;
|
||||
meta.seq_id = batch_all.seq_id;
|
||||
meta.logits = batch_all.logits;
|
||||
meta.all_pos_0 = batch_all.all_pos_0;
|
||||
meta.all_pos_1 = batch_all.all_pos_1;
|
||||
llama_send_meta(*lctx.send_socket, &meta);
|
||||
|
@ -18281,8 +18356,7 @@ static int llama_decode_internal(
|
|||
return -2;
|
||||
};
|
||||
|
||||
{ // assume there is only one batch
|
||||
// while (lctx.sbatch.n_tokens > 0) { // handle multiple batches
|
||||
while (lctx.sbatch.n_tokens > 0) { // handle multiple batches
|
||||
llama_ubatch ubatch;
|
||||
if (kv_self.recurrent) {
|
||||
if (embd_pooled) {
|
||||
|
@ -18300,26 +18374,19 @@ static int llama_decode_internal(
|
|||
|
||||
// count the outputs in this u_batch
|
||||
int32_t n_outputs_new = 0;
|
||||
|
||||
if (my_rank == 0) {
|
||||
if (n_outputs == n_tokens_all) {
|
||||
n_outputs_new = n_tokens;
|
||||
} else {
|
||||
GGML_ASSERT(ubatch.output);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
||||
}
|
||||
}
|
||||
if (n_outputs == n_tokens_all) {
|
||||
n_outputs_new = n_tokens;
|
||||
} else {
|
||||
n_outputs_new += 1;
|
||||
GGML_ASSERT(ubatch.output);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
||||
}
|
||||
}
|
||||
|
||||
// needs to happen before the graph is built
|
||||
lctx.n_outputs = n_outputs_new;
|
||||
|
||||
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
||||
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
|
||||
|
||||
GGML_ASSERT(n_threads > 0);
|
||||
|
||||
// non-causal masks do not use the KV cache
|
||||
|
@ -18394,11 +18461,11 @@ static int llama_decode_internal(
|
|||
GGML_ASSERT(my_rank == 0 || n_world > 1);
|
||||
|
||||
for (size_t i = 0; i < (size_t)gf.size(); ++i) {
|
||||
const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
|
||||
sub_gf = gf[i];
|
||||
|
||||
// receive data from other nodes
|
||||
if (n_world > 1 && !(my_rank == 0 && i == 0) && !(my_rank == 0 && is_last_l)) {
|
||||
const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
|
||||
llama_recv_tensors(*lctx.recv_socket, &ubatch, is_out_embd);
|
||||
}
|
||||
|
||||
|
@ -18407,6 +18474,10 @@ static int llama_decode_internal(
|
|||
ggml_backend_sched_synchronize(lctx.sched[i - 1]);
|
||||
}
|
||||
|
||||
ubatch.activate_input = (my_rank == 0 && i == 0);
|
||||
ubatch.activate_output = (my_rank == 0 && is_out_embd);
|
||||
GGML_ASSERT(!(ubatch.activate_input && ubatch.activate_output));
|
||||
|
||||
llama_set_inputs(lctx, ubatch);
|
||||
|
||||
{ // compute graph
|
||||
|
@ -18442,13 +18513,13 @@ static int llama_decode_internal(
|
|||
GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));
|
||||
GGML_ASSERT(backend != nullptr);
|
||||
ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);
|
||||
ggml_backend_sched_synchronize(lctx.sched[i]);
|
||||
|
||||
// send the result to the next node or the master
|
||||
if (!(n_world == 1 || (my_rank == 0 && is_last_l))) {
|
||||
struct input_tensors tensors = {sub_gf_out, lctx.inp_pos};
|
||||
const bool is_to_master = my_rank != 0 && is_last_l;
|
||||
zmq::socket_t * s = is_to_master ? lctx.master_socket : lctx.send_socket;
|
||||
ggml_backend_sched_synchronize(lctx.sched[i]);
|
||||
llama_send_tensors(*s, &ubatch, &tensors);
|
||||
}
|
||||
|
||||
|
@ -19038,7 +19109,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
|||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
|
||||
llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, true, false};
|
||||
std::vector<ggml_cgraph *> gf = llama_build_graph(lctx, ubatch, true);
|
||||
GGML_ASSERT(lctx.sched.size() == gf.size());
|
||||
|
||||
|
@ -21115,7 +21186,7 @@ void * llama_context_setup_backend(
|
|||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, true, false};
|
||||
std::vector<ggml_cgraph *> gf = llama_build_graph(*ctx, ubatch, true);
|
||||
|
||||
GGML_ASSERT(gf.size() <= MAX_SCHEDULERS && "Number of subgraphs exceeds the maximum number of schedulers\n");
|
||||
|
|
Loading…
Add table
Reference in a new issue