Mirror of https://github.com/Lizonghang/prima.cpp.git (synced 2025-09-06 07:59:05 +00:00)

Merge pull request #24 from Lizonghang/lizh_dev: Fix batch decoding and dynamic batching.

Commit c8af1be27e, 2 changed files with 125 additions and 54 deletions.
@@ -1424,13 +1424,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.defrag_thold = std::stof(value);
         }
     ).set_env("LLAMA_ARG_DEFRAG_THOLD"));
-    // add_opt(llama_arg(
-    //     {"-np", "--parallel"}, "N",
-    //     format("number of parallel sequences to decode (default: %d)", params.n_parallel),
-    //     [](gpt_params & params, int value) {
-    //         params.n_parallel = value;
-    //     }
-    // ).set_env("LLAMA_ARG_N_PARALLEL"));
+    add_opt(llama_arg(
+        {"-np", "--parallel"}, "N",
+        format("number of parallel sequences to decode (default: %d)", params.n_parallel),
+        [](gpt_params & params, int value) {
+            params.n_parallel = value;
+        }
+    ).set_env("LLAMA_ARG_N_PARALLEL"));
     add_opt(llama_arg(
         {"-ns", "--sequences"}, "N",
         format("number of sequences to decode (default: %d)", params.n_sequences),
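The previously commented-out -np / --parallel option is enabled again; like the -ns / --sequences option next to it, it parses an integer and stores it in gpt_params (here n_parallel) through a small lambda handler. A minimal standalone sketch of this option-table pattern (hypothetical types, not the real llama_arg / add_opt API):

    #include <cstdio>
    #include <cstdlib>
    #include <functional>
    #include <string>
    #include <vector>

    // Hypothetical stand-ins for gpt_params and llama_arg, for illustration only.
    struct gpt_params { int n_parallel = 1; int n_sequences = 1; };

    struct cli_arg {
        std::vector<std::string> flags;                  // e.g. {"-np", "--parallel"}
        std::function<void(gpt_params &, int)> handler;  // stores the parsed value
    };

    int main(int argc, char ** argv) {
        gpt_params params;
        std::vector<cli_arg> args = {
            {{"-np", "--parallel"},  [](gpt_params & p, int v) { p.n_parallel  = v; }},
            {{"-ns", "--sequences"}, [](gpt_params & p, int v) { p.n_sequences = v; }},
        };
        for (int i = 1; i + 1 < argc; i += 2) {
            for (const auto & a : args) {
                for (const auto & f : a.flags) {
                    if (f == argv[i]) { a.handler(params, std::atoi(argv[i + 1])); }
                }
            }
        }
        std::printf("n_parallel = %d, n_sequences = %d\n", params.n_parallel, params.n_sequences);
        return 0;
    }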
src/llama.cpp (163 changed lines)
@@ -2782,7 +2782,6 @@ struct llama_layer {
 // but has more metadata about sequences
 struct llama_ubatch {
     bool equal_seqs;
-    // TODO: whole_seqs for embeddings?

     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
@@ -2796,6 +2795,9 @@ struct llama_ubatch {
     int32_t      *  n_seq_id; // [n_seqs]
     llama_seq_id ** seq_id;   // [n_seqs]
     int8_t       *  output;   // [n_tokens]
+
+    bool activate_input;
+    bool activate_output;
 };

 struct llama_kv_cell {
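The two new flags record which end of the pipeline a micro-batch is being prepared for: activate_input for the sub-graph that consumes raw token or embedding inputs, activate_output for the sub-graph that produces the final output embeddings, and neither for intermediate sub-graphs that work on hidden states received from another node (this reading follows from how the flags are used later in this diff). A reduced, compilable sketch with the semantics spelled out in comments:

    #include <cassert>
    #include <cstdint>

    typedef int32_t llama_token;

    // Reduced sketch of the extended micro-batch; the real llama_ubatch carries more fields.
    struct ubatch_sketch {
        uint32_t      n_tokens;      // total tokens in this micro-batch
        llama_token * token;         // raw token ids (consumed where inputs are activated)
        float       * backend_embd;  // intermediate hidden states received from the previous node
        float       * out_embd;      // final output embeddings (last sub-graph on the master)
        int8_t      * output;        // per-token flag: does this token produce logits?

        // New in this commit:
        bool activate_input;         // upload token/embedding inputs for this sub-graph
        bool activate_output;        // upload the final output embeddings for this sub-graph
    };

    int main() {
        ubatch_sketch u = {};
        u.activate_input = true;     // e.g. rank 0, first sub-graph of the pipeline
        assert(!(u.activate_input && u.activate_output)); // same invariant asserted in the decode loop
        return 0;
    }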
@@ -3040,7 +3042,7 @@ struct llama_sbatch {
         ubatch_token.resize(!has_embd ? n_ubatch : 0);
         ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
         ubatch_backend_embd.resize(n_embd * n_tokens);
-        ubatch_out_embd.resize(n_embd);
+        ubatch_out_embd.resize(n_embd * n_tokens);
         ubatch_pos.resize(n_ubatch);
         ubatch_n_seq_id.resize(n_ubatch);
         ubatch_seq_id.resize(n_ubatch);
@@ -3058,6 +3060,8 @@ struct llama_sbatch {
             /*n_seq_id        =*/ ubatch_n_seq_id.data(),
             /*seq_id          =*/ ubatch_seq_id.data(),
             /*output          =*/ ubatch_output.data(),
+            /*activate_input  =*/ true,
+            /*activate_output =*/ false,
         };
         return ubatch;
     }
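Resizing ubatch_out_embd to n_embd * n_tokens gives the reserved output buffer one embedding row per token instead of a single row, which matches the n_output * n_embd * ggml_element_size(...) copy that llama_set_inputs performs below. The byte arithmetic, with made-up dimensions:

    #include <cstdio>
    #include <vector>

    int main() {
        // Hypothetical dimensions for illustration.
        const size_t n_embd   = 4096;           // hidden size
        const size_t n_tokens = 32;             // tokens in the micro-batch
        const size_t elt_size = sizeof(float);  // stand-in for ggml_element_size(tensor)

        std::vector<float> out_embd(n_embd * n_tokens);    // was: n_embd (a single row)

        const size_t bytes = n_tokens * n_embd * elt_size; // size passed to ggml_backend_tensor_set
        std::printf("out_embd holds %zu floats (%zu bytes)\n", out_embd.size(), bytes);
        return 0;
    }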
@@ -11104,7 +11108,6 @@ struct llm_build_context {
         if (il == n_layer - 1) {
             // skip computing output for unused tokens
             struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-            n_tokens = n_outputs;
             cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
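ggml_get_rows gathers only the rows listed in inp_out_ids, so the last layer keeps just the tokens that actually need logits; the removed n_tokens = n_outputs assignment no longer shadows the micro-batch size. Conceptually the gather does the following (plain-vector sketch, not the ggml API):

    #include <cstdio>
    #include <vector>

    // Gather selected rows of a row-major [n_rows x n_cols] matrix, like ggml_get_rows.
    static std::vector<float> get_rows(const std::vector<float> & src, size_t n_cols,
                                       const std::vector<int> & ids) {
        std::vector<float> dst;
        dst.reserve(ids.size() * n_cols);
        for (int id : ids) {
            dst.insert(dst.end(), src.begin() + id * n_cols, src.begin() + (id + 1) * n_cols);
        }
        return dst;
    }

    int main() {
        const size_t n_cols = 4;
        std::vector<float> hidden = {   // 3 tokens x 4 dims
            0, 1,  2,  3,
            4, 5,  6,  7,
            8, 9, 10, 11,
        };
        std::vector<int> out_ids = {2}; // only the last token produces logits
        std::vector<float> out = get_rows(hidden, n_cols, out_ids);
        std::printf("kept %zu of 3 rows\n", out.size() / n_cols);
        return 0;
    }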
@@ -16978,7 +16981,7 @@ static std::vector<struct ggml_cgraph *> llama_build_graph(

     llm.init();

-    GGML_ASSERT((model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_QWEN2) && "this model is currently not supported");
+    GGML_ASSERT((model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_QWEN2) && "this model is currently not supported.\n");

     switch (model.arch) {
         case LLM_ARCH_LLAMA:
@@ -17261,31 +17264,32 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     const auto & cparams = lctx.cparams;
     const auto & kv_self = lctx.kv_self;

-    if (batch.token) {
-        const int64_t n_tokens = batch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
-    }
-
-    if (batch.embd) {
-        const int64_t n_embd   = hparams.n_embd;
-        const int64_t n_tokens = batch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
-    }
-
-    if (batch.backend_embd && lctx.backend_embd && lctx.backend_embd->data != nullptr) {
-        const int64_t n_embd   = lctx.backend_embd->ne[0];
-        const int64_t n_tokens = lctx.backend_embd->ne[1];
-
-        ggml_backend_tensor_set(lctx.backend_embd, batch.backend_embd, 0, n_tokens*n_embd*ggml_element_size(lctx.backend_embd));
-    }
-
-    if (batch.out_embd && lctx.out_embd) {
-        const int64_t n_embd   = lctx.out_embd->ne[0];
-        const int64_t n_output = lctx.out_embd->ne[1];
-
-        ggml_backend_tensor_set(lctx.out_embd, batch.out_embd, 0, n_output*n_embd*ggml_element_size(lctx.out_embd));
+    if (batch.activate_input) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        if (batch.token) {
+            const int64_t size_ = n_tokens * ggml_element_size(lctx.inp_tokens);
+            ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, size_);
+        }
+
+        if (batch.embd) {
+            const int64_t n_embd = hparams.n_embd;
+            const int64_t size_  = n_tokens * n_embd * ggml_element_size(lctx.inp_embd);
+            ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, size_);
+        }
+    } else if (batch.activate_output) {
+        if (batch.out_embd && lctx.out_embd) {
+            const int64_t n_embd   = lctx.out_embd->ne[0];
+            const int64_t n_output = lctx.out_embd->ne[1];
+            const int64_t size_    = n_output * n_embd * ggml_element_size(lctx.out_embd);
+            ggml_backend_tensor_set(lctx.out_embd, batch.out_embd, 0, size_);
+        }
+    } else {
+        GGML_ASSERT(batch.backend_embd && lctx.backend_embd && lctx.backend_embd->data != nullptr);
+        const int64_t n_embd   = lctx.backend_embd->ne[0];
+        const int64_t n_tokens = lctx.backend_embd->ne[1];
+        const int64_t size_    = n_tokens * n_embd * ggml_element_size(lctx.backend_embd);
+        ggml_backend_tensor_set(lctx.backend_embd, batch.backend_embd, 0, size_);
     }

     if (batch.pos && lctx.inp_pos) {
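After this restructuring a device uploads exactly one kind of data per sub-graph: raw inputs when activate_input is set, the final output embeddings when activate_output is set, and otherwise the intermediate hidden states received from the previous node. A compact sketch of that three-way dispatch (illustrative enum and function, not the real llama_context API):

    #include <cassert>
    #include <cstdio>

    enum class upload_kind { input, output, intermediate };

    // Mirrors the if / else-if / else structure of the reworked llama_set_inputs.
    static upload_kind pick_upload(bool activate_input, bool activate_output) {
        assert(!(activate_input && activate_output)); // the decode loop asserts the same invariant
        if (activate_input)  return upload_kind::input;   // copy batch.token / batch.embd
        if (activate_output) return upload_kind::output;  // copy batch.out_embd
        return upload_kind::intermediate;                 // copy batch.backend_embd
    }

    int main() {
        std::printf("%d %d %d\n",
                    (int) pick_upload(true,  false),
                    (int) pick_upload(false, true),
                    (int) pick_upload(false, false));
        return 0;
    }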
@@ -17801,11 +17805,15 @@ struct input_tensors {
 };

 struct sync_meta {
     int32_t n_tokens = 0;
     llama_pos * pos  = nullptr;
+
+    int32_t      *  n_seq_id = nullptr;
+    llama_seq_id ** seq_id   = nullptr;
+    int8_t       *  logits   = nullptr;
     llama_pos all_pos_0;
     llama_pos all_pos_1;
     uint32_t n_ctx = 0;

     // signal to clear the kv cache
     bool clear_kv_cache = false;
@@ -17852,6 +17860,29 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
     }

+    if (meta->n_seq_id != nullptr) {
+        GGML_ASSERT(meta->n_ctx > 0);
+        send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
+        send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
+
+        // here we assume only a single seq_id per token is needed
+        // pack all single seq_id values into a contiguous array
+        llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_ctx * sizeof(llama_seq_id));
+        for (uint32_t i = 0; i < meta->n_ctx; ++i) {
+            all_seq_ids[i] = meta->seq_id[i][0];
+        }
+
+        send_msgs.emplace_back("seq_id", strlen("seq_id"));
+        send_msgs.emplace_back(all_seq_ids, meta->n_ctx * sizeof(llama_seq_id));
+        free(all_seq_ids);
+    }
+
+    if (meta->logits != nullptr) {
+        GGML_ASSERT(meta->n_tokens > 0);
+        send_msgs.emplace_back("logits", strlen("logits"));
+        send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t));
+    }
+
     send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
     send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
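The new n_seq_id, seq_id and logits fields reuse the multipart framing already used for pos and all_pos_0: each field travels as a key frame followed by a raw payload frame, and the per-token seq_id values are flattened into one contiguous array before sending. A self-contained cppzmq sketch of that key/payload framing (the inproc endpoint and the single "logits" field are illustrative):

    #include <cstdint>
    #include <cstring>
    #include <string>
    #include <vector>
    #include <zmq.hpp>
    #include <zmq_addon.hpp>

    int main() {
        zmq::context_t ctx;
        zmq::socket_t tx(ctx, zmq::socket_type::push);
        zmq::socket_t rx(ctx, zmq::socket_type::pull);
        rx.bind("inproc://meta");
        tx.connect("inproc://meta");

        // Sender: a key frame, then a payload frame, for each field.
        std::vector<int8_t> logits = {0, 0, 0, 1};
        std::vector<zmq::message_t> send_msgs;
        send_msgs.emplace_back("logits", strlen("logits"));
        send_msgs.emplace_back(logits.data(), logits.size() * sizeof(int8_t));
        zmq::send_multipart(tx, send_msgs);

        // Receiver: walk the frames in (key, payload) pairs.
        std::vector<zmq::message_t> recv_msgs;
        if (!zmq::recv_multipart(rx, std::back_inserter(recv_msgs))) {
            return 1;
        }
        for (size_t i = 0; i + 1 < recv_msgs.size(); i += 2) {
            std::string key((const char *) recv_msgs[i].data(), recv_msgs[i].size());
            if (key == "logits") {
                std::vector<int8_t> got(recv_msgs[i + 1].size());
                std::memcpy(got.data(), recv_msgs[i + 1].data(), got.size());
            }
        }
        return 0;
    }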
@@ -17931,6 +17962,31 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
     }

+    if (key == "n_seq_id") {
+        GGML_ASSERT(meta->n_ctx > 0);
+        GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
+        meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
+        std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
+    }
+
+    if (key == "seq_id") {
+        GGML_ASSERT(meta->n_ctx > 0);
+        GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(llama_seq_id));
+        const llama_seq_id * all_seq_ids = (llama_seq_id *) data_msg.data();
+        meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
+        for (uint32_t i = 0; i < meta->n_ctx; ++i) {
+            meta->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
+            meta->seq_id[i][0] = all_seq_ids[i];
+        }
+    }
+
+    if (key == "logits") {
+        GGML_ASSERT(meta->n_tokens > 0);
+        GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
+        meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
+        std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
+    }
+
     if (key == "all_pos_0") {
         GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
         std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
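On the receive side every field is rebuilt with malloc: one int32_t array for n_seq_id, one single-element allocation per token behind seq_id[i], and one int8_t array for logits, so the consumer of the reconstructed batch owns these buffers. A hedged sketch of a matching cleanup helper (the commit itself does not show where, or whether, these allocations are released):

    #include <cstdint>
    #include <cstdlib>

    typedef int32_t llama_seq_id;

    // Minimal mirror of the malloc'd sync_meta fields, for illustrating ownership only.
    struct sync_meta_sketch {
        uint32_t        n_ctx    = 0;
        int32_t       * n_seq_id = nullptr;
        llama_seq_id ** seq_id   = nullptr;
        int8_t        * logits   = nullptr;
    };

    static void sync_meta_free(sync_meta_sketch & meta) {
        std::free(meta.n_seq_id);
        if (meta.seq_id) {
            for (uint32_t i = 0; i < meta.n_ctx; ++i) {
                std::free(meta.seq_id[i]);  // each token got its own single-element allocation
            }
            std::free(meta.seq_id);
        }
        std::free(meta.logits);
        meta = sync_meta_sketch{};
    }

    int main() {
        sync_meta_sketch meta;
        meta.n_ctx    = 4;
        meta.n_seq_id = (int32_t *) std::calloc(meta.n_ctx, sizeof(int32_t));
        meta.seq_id   = (llama_seq_id **) std::calloc(meta.n_ctx, sizeof(llama_seq_id *));
        for (uint32_t i = 0; i < meta.n_ctx; ++i) {
            meta.seq_id[i] = (llama_seq_id *) std::calloc(1, sizeof(llama_seq_id));
        }
        meta.logits = (int8_t *) std::calloc(meta.n_ctx, sizeof(int8_t));
        sync_meta_free(meta);
        return 0;
    }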
@@ -17971,6 +18027,7 @@ static void llama_recv_tensors(zmq::socket_t & socket, struct llama_ubatch * uba
     std::vector<zmq::message_t> recv_msgs;
     if (!zmq::recv_multipart(socket, std::back_inserter(recv_msgs))) {
         LLAMA_LOG_INFO("Failed to receive tensor data.\n");
+        return;
     }

     for (size_t i = 0; i < recv_msgs.size(); i += 3) {
@@ -18218,6 +18275,21 @@ static int llama_decode_internal(
             batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
             std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
         }
+        if (meta.n_seq_id != nullptr) {
+            batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t));
+            std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t));
+        }
+        if (meta.seq_id != nullptr) {
+            batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
+            for (size_t i = 0; i < cparams.n_ctx; ++i) {
+                batch_all.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
+                batch_all.seq_id[i][0] = meta.seq_id[i][0];
+            }
+        }
+        if (meta.logits != nullptr) {
+            batch_all.logits = (int8_t *) malloc(meta.n_tokens * sizeof(int8_t));
+            std::memcpy(batch_all.logits, meta.logits, meta.n_tokens * sizeof(int8_t));
+        }
         batch_all.all_pos_0 = meta.all_pos_0;
         batch_all.all_pos_1 = meta.all_pos_1;
     }
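Because llama_batch and llama_sbatch expect seq_id as an array of per-token pointers (llama_seq_id **), the worker deep-copies the received values into that layout instead of aliasing the flat array it got over the wire. In miniature, with hypothetical sizes:

    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>

    typedef int32_t llama_seq_id;

    int main() {
        const uint32_t n_ctx = 4;

        // "Received" flat array, one seq_id per token (as packed by llama_send_meta).
        llama_seq_id recv[n_ctx] = {0, 0, 1, 1};

        // Deep copy into the pointer-per-token layout used by llama_batch / llama_sbatch.
        llama_seq_id ** seq_id = (llama_seq_id **) std::malloc(n_ctx * sizeof(llama_seq_id *));
        for (uint32_t i = 0; i < n_ctx; ++i) {
            seq_id[i]    = (llama_seq_id *) std::malloc(sizeof(llama_seq_id));
            seq_id[i][0] = recv[i];
        }

        std::printf("token 2 belongs to sequence %d\n", seq_id[2][0]);

        for (uint32_t i = 0; i < n_ctx; ++i) std::free(seq_id[i]);
        std::free(seq_id);
        return 0;
    }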
@@ -18266,6 +18338,9 @@ static int llama_decode_internal(
     if (!is_last_dev) {
         meta.n_tokens  = batch_all.n_tokens;
         meta.pos       = batch_all.pos;
+        meta.n_seq_id  = batch_all.n_seq_id;
+        meta.seq_id    = batch_all.seq_id;
+        meta.logits    = batch_all.logits;
         meta.all_pos_0 = batch_all.all_pos_0;
         meta.all_pos_1 = batch_all.all_pos_1;
         llama_send_meta(*lctx.send_socket, &meta);
@@ -18281,8 +18356,7 @@ static int llama_decode_internal(
         return -2;
     };

-    { // assume there is only one batch
-    // while (lctx.sbatch.n_tokens > 0) { // handle multiple batches
+    while (lctx.sbatch.n_tokens > 0) { // handle multiple batches
         llama_ubatch ubatch;
         if (kv_self.recurrent) {
             if (embd_pooled) {
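This is the dynamic-batching half of the PR title: instead of assuming everything fits into a single micro-batch, the decode path now keeps splitting lctx.sbatch until no tokens remain. A toy sketch of the loop shape (stand-in types, not the real llama_sbatch API):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Toy stand-in: split a batch of n_tokens into micro-batches of at most n_ubatch tokens,
    // mirroring the new `while (lctx.sbatch.n_tokens > 0)` decode loop.
    struct sbatch_sketch {
        uint32_t n_tokens = 0;
        uint32_t split(uint32_t n_ubatch) {        // returns the micro-batch size taken
            const uint32_t n = std::min(n_tokens, n_ubatch);
            n_tokens -= n;
            return n;
        }
    };

    int main() {
        sbatch_sketch sbatch;
        sbatch.n_tokens = 70;                      // hypothetical prompt length
        const uint32_t n_ubatch = 32;              // hypothetical micro-batch size

        while (sbatch.n_tokens > 0) {              // previously: a single-iteration block
            const uint32_t n = sbatch.split(n_ubatch);
            std::printf("decode micro-batch of %u tokens, %u remaining\n", n, sbatch.n_tokens);
            // ... build graph(s), set inputs, compute, send/receive tensors ...
        }
        return 0;
    }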
@@ -18300,26 +18374,19 @@ static int llama_decode_internal(
         // count the outputs in this u_batch
         int32_t n_outputs_new = 0;

-        if (my_rank == 0) {
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
-            }
-        } else {
-            n_outputs_new += 1;
-        }
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < n_tokens; i++) {
+                n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+            }
+        }

         // needs to happen before the graph is built
         lctx.n_outputs = n_outputs_new;

         int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
         ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

         GGML_ASSERT(n_threads > 0);

         // non-causal masks do not use the KV cache
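All ranks now count outputs the same way: either every token of the batch produces an output, or the per-token output mask of the micro-batch is summed (the old rank-dependent n_outputs_new += 1 shortcut is gone). A small sketch of the mask-based count:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const uint32_t n_tokens_all = 5;               // tokens in the whole batch (hypothetical)
        const uint32_t n_outputs    = 1;               // caller only wants logits for the last token
        std::vector<int8_t> output  = {0, 0, 0, 0, 1}; // per-token output mask of one micro-batch

        int32_t n_outputs_new = 0;
        if (n_outputs == n_tokens_all) {
            n_outputs_new = (int32_t) output.size();   // every token in the micro-batch is an output
        } else {
            for (int8_t o : output) {
                n_outputs_new += (int32_t) (o != 0);
            }
        }
        std::printf("n_outputs_new = %d\n", n_outputs_new);
        return 0;
    }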
@@ -18394,11 +18461,11 @@ static int llama_decode_internal(
         GGML_ASSERT(my_rank == 0 || n_world > 1);

         for (size_t i = 0; i < (size_t)gf.size(); ++i) {
+            const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
             sub_gf = gf[i];

             // receive data from other nodes
             if (n_world > 1 && !(my_rank == 0 && i == 0) && !(my_rank == 0 && is_last_l)) {
-                const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
                 llama_recv_tensors(*lctx.recv_socket, &ubatch, is_out_embd);
             }

@@ -18407,6 +18474,10 @@ static int llama_decode_internal(
                 ggml_backend_sched_synchronize(lctx.sched[i - 1]);
             }

+            ubatch.activate_input  = (my_rank == 0 && i == 0);
+            ubatch.activate_output = (my_rank == 0 && is_out_embd);
+            GGML_ASSERT(!(ubatch.activate_input && ubatch.activate_output));
+
             llama_set_inputs(lctx, ubatch);

             { // compute graph
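Together with the is_out_embd flag hoisted out of the receive branch in the previous hunk, each iteration now tags the micro-batch with its position in the pipeline before llama_set_inputs runs. A sketch of the flag computation for one device running several sub-graphs (the wrapper function and demo loop are illustrative; the expressions match the hunk):

    #include <cassert>
    #include <cstdio>

    struct activation_flags { bool input; bool output; };

    // Rank 0 owns both ends of the pipeline: its first sub-graph consumes the raw inputs,
    // its last sub-graph produces the final output embeddings.
    static activation_flags flags_for(int my_rank, size_t i, size_t n_subgraphs) {
        const bool is_out_embd = my_rank == 0 && i == n_subgraphs - 1;
        activation_flags f;
        f.input  = (my_rank == 0 && i == 0);
        f.output = (my_rank == 0 && is_out_embd);
        assert(!(f.input && f.output));  // same invariant as the GGML_ASSERT in the diff
        return f;
    }

    int main() {
        const size_t n_subgraphs = 3;    // hypothetical number of sub-graphs on this device
        for (int rank = 0; rank < 2; ++rank) {
            for (size_t i = 0; i < n_subgraphs; ++i) {
                activation_flags f = flags_for(rank, i, n_subgraphs);
                std::printf("rank %d, sub-graph %zu: input=%d output=%d\n", rank, i, f.input, f.output);
            }
        }
        return 0;
    }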
@@ -18442,13 +18513,13 @@ static int llama_decode_internal(
                 GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));
                 GGML_ASSERT(backend != nullptr);
                 ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);
+                ggml_backend_sched_synchronize(lctx.sched[i]);

                 // send the result to the next node or the master
                 if (!(n_world == 1 || (my_rank == 0 && is_last_l))) {
                     struct input_tensors tensors = {sub_gf_out, lctx.inp_pos};
                     const bool is_to_master = my_rank != 0 && is_last_l;
                     zmq::socket_t * s = is_to_master ? lctx.master_socket : lctx.send_socket;
-                    ggml_backend_sched_synchronize(lctx.sched[i]);
                     llama_send_tensors(*s, &ubatch, &tensors);
                 }

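The scheduler fence now runs immediately after the asynchronous device-to-host copy, so embd_buf is complete before it is packed into ZeroMQ messages, and also when no send happens at all (previously the synchronize call only ran inside the send branch). A generic sketch of that ordering, using a standard-library async copy as a stand-in for the ggml backend calls named in the hunk:

    #include <cstdio>
    #include <cstring>
    #include <future>
    #include <vector>

    int main() {
        std::vector<float> device_result(1024, 1.0f);  // stands in for sub_gf_out on the backend
        std::vector<float> embd_buf(1024);             // host-side staging buffer

        // 1. ggml_backend_tensor_get_async(...): start the device-to-host copy.
        std::future<void> copy = std::async(std::launch::async, [&] {
            std::memcpy(embd_buf.data(), device_result.data(), embd_buf.size() * sizeof(float));
        });

        // 2. ggml_backend_sched_synchronize(...): wait for the copy to finish
        //    BEFORE the buffer is sent or otherwise consumed.
        copy.wait();

        // 3. llama_send_tensors(...): now it is safe to hand embd_buf to the network layer.
        std::printf("first value after sync: %f\n", embd_buf[0]);
        return 0;
    }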
@@ -19038,7 +19109,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
         llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, true, false};
         std::vector<ggml_cgraph *> gf = llama_build_graph(lctx, ubatch, true);
         GGML_ASSERT(lctx.sched.size() == gf.size());

@@ -21115,7 +21186,7 @@ void * llama_context_setup_backend(
     uint32_t n_seqs = 1; // TODO: worst-case number of sequences
     uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
     llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, true, false};
     std::vector<ggml_cgraph *> gf = llama_build_graph(*ctx, ubatch, true);

    GGML_ASSERT(gf.size() <= MAX_SCHEDULERS && "Number of subgraphs exceeds the maximum number of schedulers\n");
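Both call sites build a placeholder llama_ubatch by positional aggregate initialization, so the two appended values land on the two new trailing fields: activate_input = true, activate_output = false. A reduced sketch of that mapping (subset of the real struct, field order assumed from the struct hunks above):

    #include <cstdint>
    #include <cstdio>

    typedef int32_t llama_token;

    // Field order mirrors the (reduced) llama_ubatch from the hunks above.
    struct ubatch_sketch {
        bool          equal_seqs;
        uint32_t      n_tokens;
        uint32_t      n_seq_tokens;
        uint32_t      n_seqs;
        llama_token * token;
        bool          activate_input;   // new trailing field
        bool          activate_output;  // new trailing field
    };

    int main() {
        uint32_t n_tokens = 16, n_seqs = 1;
        llama_token token = 1;  // placeholder, as in the worst-case graph reservation

        // Positional init: the final `true, false` land on activate_input / activate_output.
        ubatch_sketch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, true, false };

        std::printf("activate_input=%d activate_output=%d\n", ubatch.activate_input, ubatch.activate_output);
        return 0;
    }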