mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-05 22:59:06 +00:00
Restored support for calculating perplexity in standalone test models
parent ef1e10101e
commit e38f13ba17
7 changed files with 45028 additions and 68 deletions
Makefile (3 changed lines)
@ -2,7 +2,8 @@
 BUILD_TARGETS = \
 	llama-server \
 	llama-cli \
-	profile-tool
+	profile-tool \
+	llama-perplexity

 # BUILD_TARGETS = \
 #	libllava.a \
@ -586,6 +586,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     int n_outputs = 0;

+    batch.n_tokens = 0;

     for (int seq = 0; seq < n_seq_batch; seq++) {
         int seq_start = batch_start + seq*n_ctx;

@ -599,6 +600,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     for (int k = 0; k < batch_size; ++k) {
         const int idx = seq*n_ctx + k;

         batch.token   [idx] = tokens[seq_start + k];
         batch.pos     [idx] = j*n_batch + k;
         batch.n_seq_id[idx] = 1;
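Note: for orientation, here is a minimal sketch of how these batch fields fit together when packing one perplexity chunk, written against the public llama_batch layout from llama.h. The helper name fill_chunk and the per-token logits flag are illustrative assumptions, not code from this commit; the batch is assumed to have been allocated with llama_batch_init() with enough capacity and n_seq_max >= 1.

```cpp
#include <vector>
#include "llama.h"

// Hypothetical helper mirroring the loop above: pack n_seq_batch slices of
// n_ctx tokens each into a pre-allocated llama_batch.
static void fill_chunk(llama_batch & batch, const std::vector<llama_token> & tokens,
                       int batch_start, int n_seq_batch, int n_ctx,
                       int batch_size, int j, int n_batch) {
    batch.n_tokens = 0; // reset before packing, as the hunk above does
    for (int seq = 0; seq < n_seq_batch; seq++) {
        int seq_start = batch_start + seq*n_ctx;
        for (int k = 0; k < batch_size; ++k) {
            const int idx = seq*n_ctx + k;
            batch.token   [idx]    = tokens[seq_start + k];
            batch.pos     [idx]    = j*n_batch + k; // absolute position within the sequence
            batch.n_seq_id[idx]    = 1;             // each token belongs to exactly one sequence
            batch.seq_id  [idx][0] = seq;
            batch.logits  [idx]    = true;          // assumption: request logits for every position
            batch.n_tokens++;
        }
    }
}
```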
scripts/wikitext-2-raw-v1.zip (new binary file, not shown)
scripts/wikitext-2-raw/wiki.test.raw (new file, 4358 lines; diff suppressed, file too large)
scripts/wikitext-2-raw/wiki.train.raw (new file, 36718 lines; diff suppressed, file too large)
scripts/wikitext-2-raw/wiki.valid.raw (new file, 3760 lines; diff suppressed, file too large)
src/llama.cpp (255 changed lines)
@ -2855,6 +2855,16 @@ struct llama_kv_cache {
         return size;
     }

+    llama_pos get_pos_max() const {
+        llama_pos pos_max = -1;
+
+        for (const auto & cell : cells) {
+            pos_max = std::max(pos_max, cell.pos);
+        }
+
+        return pos_max;
+    }
+
     ~llama_kv_cache() {
         for (struct ggml_context * ctx : ctxs) {
             ggml_free(ctx);
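Note: the new get_pos_max() pairs with the llama_batch_allocr introduced further down in this diff: when a batch arrives without positions, decoding can continue from one past the highest position already stored in the KV cache. A fragment adapted from the commented-out call later in this diff (the call site is not actually enabled by this commit; in_batch and lctx are stand-in names):

```cpp
// in_batch may come from llama_batch_get_one() and carry no positions;
// in that case, start right after the cache's current maximum position.
// Passing -1 keeps whatever positions the caller already provided.
llama_pos p0 = in_batch.pos ? -1 : lctx.kv_self.get_pos_max() + 1;
llama_batch_allocr batch_allocr(in_batch, p0);
const llama_batch & batch = batch_allocr.batch; // now has pos/n_seq_id/seq_id/logits
```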
@ -3039,8 +3049,11 @@ struct llama_sbatch {
     }
     ubatch_token.resize(!has_embd ? n_ubatch : 0);
     ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
-    ubatch_backend_embd.resize(n_embd * n_tokens);
-    ubatch_out_embd.resize(n_embd);
+
+    // TODO: just a guess and test, needs to be removed (from tao)
+    ubatch_backend_embd.resize(n_embd * n_tokens * 3);
+    ubatch_out_embd.resize(n_embd * n_tokens);
+
     ubatch_pos.resize(n_ubatch);
     ubatch_n_seq_id.resize(n_ubatch);
     ubatch_seq_id.resize(n_ubatch);
@ -3156,7 +3169,7 @@ struct llama_sbatch {
         } else {
             // simple split
             ubatch.output = batch->logits + seq.offset;
             for (size_t i = 0; i < length; ++i) {
                 if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
             }
         }
@ -3189,7 +3202,7 @@ struct llama_sbatch {
         llama_sbatch_seq & s = seq[0];
         size_t length = s.length < n_ubatch ? s.length : n_ubatch;
         GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
         add_seq_to_ubatch(ubatch, s, length);
     }
     return ubatch;
 }
@ -3326,6 +3339,51 @@ struct llama_sbatch {
     }
 };

+struct llama_batch_allocr {
+    struct llama_batch batch;
+
+    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t>         logits;
+
+    // temporarily allocate memory for the input batch if needed
+    // optionally fill in the batch returned by llama_batch_get_one
+    llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
+        batch = in_batch;
+        GGML_ASSERT(batch.n_tokens > 0);
+        if (!batch.pos) {
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = i + p0;
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            seq_id.resize(batch.n_tokens + 1);
+            seq_id[batch.n_tokens] = NULL;
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            logits.resize(batch.n_tokens);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+    }
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
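Note: a hedged usage sketch of the allocator above. It assumes the two-argument form of llama_batch_get_one(), which returns a batch with only token and n_tokens set; the values in the trailing comments follow directly from the constructor logic.

```cpp
std::vector<llama_token> toks = { 10, 20, 30 };
llama_batch bare = llama_batch_get_one(toks.data(), toks.size());

// the allocator fills in every array the caller left null
llama_batch_allocr allocr(bare, /*p0=*/0);
llama_batch & batch = allocr.batch;
// batch.pos      -> 0, 1, 2
// batch.n_seq_id -> 1, 1, 1   (every token in default sequence 0)
// batch.logits   -> 0, 0, 1   (only the last token produces logits)
```

Note that the filled-in pointers reference the allocator's own vectors, so the allocator must outlive any use of the batch.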
@ -3381,17 +3439,17 @@ struct llama_context {
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits = nullptr;

-    // embeddings output (2-dimensional array: [n_outputs][n_embd])
-    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    size_t  embd_size = 0; // capacity (of floats) for embeddings
-    float * embd = nullptr;
-
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
     size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
     int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch or last logical batch

     bool logits_all = false;

+    // embeddings output (2-dimensional array: [n_outputs][n_embd])
+    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
+    size_t  embd_size = 0; // capacity (of floats) for embeddings
+    float * embd = nullptr;
+
     // sequence embeddings output (map of [n_embd] vectors)
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
     std::map<llama_seq_id, std::vector<float>> embd_seq;
@ -4194,6 +4252,7 @@ static bool llama_kv_cache_find_slot(

             for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
                 cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+                // printf("batch.seq_id[%u][%d] = %u\n", s, j, batch.seq_id[s][j]);
             }
         }
     }
@ -10991,6 +11050,7 @@ struct llm_build_context {

         // create a vector to hold sub-graphs
         std::vector<struct ggml_cgraph *> sub_gfs;

         struct ggml_cgraph * sub_gf = nullptr;
         struct ggml_tensor * cur    = nullptr;
         struct ggml_tensor * inpL   = nullptr;
@ -17363,7 +17423,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {

         for (int i = 0; i < n_kv; ++i) {
             float f;
-            if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+            if (!kv_self.cells[i].has_seq_id(seq_id)
+                || (cparams.causal_attn && kv_self.cells[i].pos > pos)) {
                 f = -INFINITY;
             } else {
                 if (hparams.use_alibi) {
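Note: the rewritten condition applies the future-position mask only under causal attention. A self-contained restatement of the rule (the helper is illustrative; the real code adds an ALiBi slope term in place of 0.0f when hparams.use_alibi is set):

```cpp
#include <cmath> // INFINITY

// A KV cell is masked out when it belongs to a different sequence, or,
// only under causal attention, when it lies in the future of the query.
static float mask_value(bool same_seq, bool causal_attn, int cell_pos, int query_pos) {
    if (!same_seq || (causal_attn && cell_pos > query_pos)) {
        return -INFINITY;
    }
    return 0.0f;
}
```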
@ -17681,8 +17742,13 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;

     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
-    const bool has_embd   = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = !cparams.embeddings;
+    bool has_embd   = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+
+    if (lctx.model.arch == LLM_ARCH_T5) {
+        has_logits = true;
+        has_embd   = true;
+    }

     const size_t logits_size = has_logits ? n_vocab * n_outputs_max : 0;
     const size_t embd_size   = has_embd   ? n_embd   * n_outputs_max : 0;
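Note: to make the sizing rule concrete, a back-of-the-envelope sketch with assumed llama-7B-like dimensions (n_vocab = 32000, n_embd = 4096; both values are illustrative, not taken from this diff):

```cpp
#include <cstdio>

int main() {
    const size_t n_vocab = 32000, n_embd = 4096, n_outputs_max = 512;
    const bool has_logits = true;   // !cparams.embeddings
    const bool has_embd   = false;  // embeddings off, so no embd buffer

    const size_t logits_size = has_logits ? n_vocab * n_outputs_max : 0; // in floats
    const size_t embd_size   = has_embd   ? n_embd   * n_outputs_max : 0;

    // 32000 * 512 * 4 bytes = 65,536,000 bytes, about 62.5 MiB of fp32 logits
    printf("logits: %zu floats (%.1f MiB), embd: %zu floats\n",
           logits_size, logits_size * sizeof(float) / (1024.0 * 1024.0), embd_size);
    return 0;
}
```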
@ -18143,21 +18209,33 @@ static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool f
 //
 static int llama_decode_internal(
           llama_context & lctx,
-          llama_batch     batch_all) { // TODO: rename back to batch
+          llama_batch   & batch_all) { // TODO: rename back to batch

     // llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.get_pos_max() + 1);

     // // const llama_batch & batch_all = batch_allocr.batch;
     // llama_batch & batch_all = batch_allocr.batch;

     lctx.is_encoding = false;

     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
     const auto & n_vocab = hparams.n_vocab;

     const uint32_t n_world = cparams.n_world;
     const uint32_t my_rank = cparams.rank;

     const uint32_t n_tokens_all = batch_all.n_tokens;
     const int64_t  n_embd       = hparams.n_embd; // used for reserving embedding buffer space

     if (my_rank != 0) {
         batch_all.token = nullptr;
     }

     GGML_ASSERT(!(my_rank == 0 && n_tokens_all == 0) && "n_tokens == 0 on master node");

     GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
     if (batch_all.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             if (batch_all.token[i] < 0 || (uint32_t) batch_all.token[i] >= model.vocab.n_vocab) {
@ -18178,31 +18256,28 @@ static int llama_decode_internal(

     auto & kv_self = lctx.kv_self;

-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = hparams.n_vocab;
-
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
+    // const auto n_ubatch = cparams.n_ubatch;

     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

     lctx.embd_seq.clear();

+    int64_t n_outputs = 0; // the number of tokens to output
+
     // count outputs
-    if (batch_all.logits && !embd_pooled) {
+    if (batch_all.logits && !embd_pooled) { // batch_all.logits specifies which positions need output
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all || embd_pooled) {
-        n_outputs = n_tokens_all;
+    } else if (embd_pooled) { // all input tokens need to be output
+        n_outputs = n_tokens_all;
     } else {
         // keep last output only
         n_outputs = 1;
     }

+    // TODO: needs to be encapsulated into a function
     sync_meta meta;
     meta.n_ctx = cparams.n_ctx;
     bool is_last_dev = (my_rank == n_world - 1);
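Note: a small worked restatement of the three counting branches above; the helper and the example values are illustrative, not part of the commit.

```cpp
#include <cstdint>

static int64_t count_outputs(const int8_t * logits, uint32_t n_tokens_all, bool embd_pooled) {
    int64_t n_outputs = 0;
    if (logits && !embd_pooled) {     // explicit per-position flags
        for (uint32_t i = 0; i < n_tokens_all; ++i) {
            n_outputs += logits[i] != 0;
        }
    } else if (embd_pooled) {         // pooled embeddings: every token is an output
        n_outputs = n_tokens_all;
    } else {                          // default: last token only
        n_outputs = 1;
    }
    return n_outputs;
}

// e.g. flags {0, 0, 1, 1} over 4 tokens -> 2; pooled -> 4; neither -> 1
```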
@ -18281,50 +18356,49 @@ static int llama_decode_internal(
         return -2;
     };

-    { // assume there is only one batch
-    // while (lctx.sbatch.n_tokens > 0) { // handle multiple batches
+    uint32_t n_outputs_prev = 0;
+
+    // { // assume there is only one batch
+    while (lctx.sbatch.n_tokens > 0) { // handle multiple batches

         llama_ubatch ubatch;
         if (kv_self.recurrent) {
             if (embd_pooled) {
                 // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
+                ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
             } else {
                 // recurrent model architectures are easier to implement
                 // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
+                ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
             }
         } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
+            ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
         }
-        const uint32_t n_tokens = ubatch.n_tokens;
+
+        const uint32_t n_tokens = ubatch.n_tokens; // number of tokens in this ubatch
+
+        llama_kv_cache_update(&lctx);

         // count the outputs in this u_batch
         int32_t n_outputs_new = 0;

         if (my_rank == 0) {
-            if (n_outputs == n_tokens_all) {
+            // n_outputs is the number of tokens to output in the input batch
+            if (n_outputs == n_tokens_all) { // all completed tokens have logits
                 n_outputs_new = n_tokens;
             } else {
                 GGML_ASSERT(ubatch.output);
                 for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0); // ubatch.output[i] == 0 means no logits for that token
                 }
             }
         } else {
             n_outputs_new += 1;
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
         }

+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
-
-        GGML_ASSERT(n_threads > 0);
-
         // non-causal masks do not use the KV cache
         if (hparams.causal_attn) {
             llama_kv_cache_update(&lctx);

             // if we have enough unused cells before the current head ->
             //   better to start searching from the beginning of the cache, hoping to fill it
@ -18333,6 +18407,7 @@ static int llama_decode_internal(
             }

             if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV Cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
                 return 1;
             }
@ -18351,15 +18426,15 @@ static int llama_decode_internal(
         }

         std::vector<ggml_cgraph *> gf = llama_build_graph(lctx, ubatch, false);
         GGML_ASSERT(lctx.sched.size() == gf.size());

         // the output is always the last tensor in the graph
         struct ggml_tensor * res        = nullptr;
         struct ggml_tensor * embd       = nullptr;
         struct ggml_tensor * sub_gf_out = nullptr;
         const int64_t n_embd = hparams.n_embd;

         if (my_rank == 0) {
             res  = ggml_graph_node(gf.back(), -1);
             embd = ggml_graph_node(gf.back(), -2);
         }

         if (lctx.n_outputs == 0) {
@ -18380,7 +18455,6 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
         }

-        GGML_ASSERT(lctx.sched.size() == gf.size());
         for (size_t i = 0; i < (size_t) lctx.sched.size(); ++i) {
             ggml_backend_sched_alloc_graph(lctx.sched[i], gf[i]);
         }
@ -18409,7 +18483,12 @@ static int llama_decode_internal(

             llama_set_inputs(lctx, ubatch);

-            { // compute graph
+            {
+                int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+                ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+                GGML_ASSERT(n_threads > 0);
+
+                // compute graph
                 timer(llama_graph_compute);
                 llama_graph_compute(lctx, sub_gf, lctx.sched[i], n_threads, threadpool);
             }
@ -18427,23 +18506,25 @@ static int llama_decode_internal(
                 cur_l = std::atoi(layer_str);
                 is_last_l = (cur_l == static_cast<int>(n_layer) - 1);
             }

+            size_t n_elem   = sub_gf_out->ne[0] * sub_gf_out->ne[1];
+            size_t buf_size = n_elem * sizeof(float);
+
             float * embd_buf;
             if (n_world == 1 || (my_rank == 0 && is_last_l)) {
-                embd_buf = is_last_l ? ubatch.out_embd : ubatch.backend_embd;
+                embd_buf = is_last_l ? ubatch.out_embd : ubatch.backend_embd; // backend_embd: intermediate results computed by backend devices
             } else {
                 embd_buf = ubatch.backend_embd;
             }
             GGML_ASSERT(embd_buf != nullptr);

-            // copy device data to cpu memory
-            size_t buf_size = sub_gf_out->ne[0] * sub_gf_out->ne[1] * sizeof(float);
             ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(lctx.sched[i], sub_gf_out);
+            GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));
             GGML_ASSERT(backend != nullptr);
-            ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);

+            // copy device data to cpu memory
+            ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);

-            // send the result to the next node or the master
+            // send the result to the next node or the master (only for distributed environments)
             if (!(n_world == 1 || (my_rank == 0 && is_last_l))) {
                 struct input_tensors tensors = {sub_gf_out, lctx.inp_pos};
                 const bool is_to_master = my_rank != 0 && is_last_l;
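Note: the reshuffled copy is the standard ggml backend pattern: size the host buffer from the tensor's dimensions, guard against overrun, then read asynchronously on the tensor's own backend. A sketch under assumed names (t, host_buf, and sched stand in for sub_gf_out, embd_buf, and the scheduler above):

```cpp
// t:        tensor produced by the sub-graph (e.g. sub_gf_out)
// host_buf: pre-allocated CPU buffer of at least ne[0]*ne[1] floats
size_t n_elem   = t->ne[0] * t->ne[1];
size_t buf_size = n_elem * sizeof(float);
GGML_ASSERT(buf_size <= ggml_nbytes(t)); // never read past the tensor

ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, t);
GGML_ASSERT(backend != nullptr);
ggml_backend_tensor_get_async(backend, t, host_buf, 0, buf_size);
ggml_backend_sched_synchronize(sched);   // wait before host_buf is read
```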
@ -18475,18 +18556,18 @@ static int llama_decode_internal(
             }

             // extract logits
-            if (res) {
+            if (my_rank == 0 && (res && lctx.n_outputs > 0)) {
                 ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched.back(), res);
                 GGML_ASSERT(backend_res != nullptr);
                 GGML_ASSERT(lctx.logits != nullptr);

-                float * logits_out = lctx.logits + n_outputs_prev * n_vocab;
-                const int32_t n_outputs_new = lctx.n_outputs;
+                float * logits_out = lctx.logits + n_outputs_prev * n_vocab; // n_outputs_prev * n_vocab is the offset from the start of the logits buffer

-                if (n_outputs_new) {
-                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
-                    GGML_ASSERT((n_outputs_prev + n_outputs_new) * n_vocab <= (int64_t) lctx.logits_size);
-                    ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new * n_vocab * sizeof(float));
+                if (lctx.n_outputs) {
+                    GGML_ASSERT( n_outputs_prev + lctx.n_outputs <= n_outputs);
+                    GGML_ASSERT((n_outputs_prev + lctx.n_outputs) * n_vocab <= (int64_t) lctx.logits_size);
+
+                    ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, lctx.n_outputs * n_vocab * sizeof(float));
                 }
             }
@ -18547,17 +18628,26 @@ static int llama_decode_internal(
                 }
             }
         }

+        n_outputs_prev += lctx.n_outputs;
+    }

-    lctx.n_outputs = n_outputs;

     if (my_rank == 0) {
         // set output mappings
         bool sorted_output = true;
-        GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);

-        for (size_t i = 0; i < n_outputs; ++i) {
-            size_t out_id = lctx.sbatch.out_ids[i];
-            lctx.output_ids[out_id] = i;
+        auto & out_ids    = lctx.sbatch.out_ids;
+        auto & output_ids = lctx.output_ids;
+
+        // printf("DEBUG: out_ids.size() = %zu\n", out_ids.size());
+        // printf("DEBUG: output_ids.size() = %zu\n", output_ids.size());
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
+
+        for (int64_t i = 0; i < n_outputs; ++i) {
+            int64_t out_id = out_ids[i];
+            output_ids[out_id] = i; // if the output is unordered, the value here is irrelevant
             if (out_id != i) {
                 sorted_output = false;
             }
@ -18565,10 +18655,41 @@ static int llama_decode_internal(

         if (sorted_output) {
             lctx.sbatch.out_ids.clear();
-        }
+        } else {
+            GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+            // TODO: is there something more efficient which also minimizes swaps?
+            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+            for (int32_t i = 0; i < n_outputs - 1; ++i) {
+                int32_t j_min = i;
+                for (int32_t j = i + 1; j < n_outputs; ++j) {
+                    if (out_ids[j] < out_ids[j_min]) {
+                        j_min = j;
+                    }
+                }
+                if (j_min == i) { continue; }
+                std::swap(out_ids[i], out_ids[j_min]);
+
+                if (lctx.logits_size > 0) {
+                    for (uint32_t k = 0; k < n_vocab; k++) {
+                        std::swap(lctx.logits[i*n_vocab + k], lctx.logits[j_min*n_vocab + k]);
+                    }
+                }
+
+                if (lctx.embd_size > 0) {
+                    for (uint32_t k = 0; k < n_embd; k++) {
+                        std::swap(lctx.embd[i*n_embd + k], lctx.embd[j_min*n_embd + k]);
+                    }
+                }
+            }
+            std::fill(output_ids.begin(), output_ids.end(), -1);
+            for (int32_t i = 0; i < n_outputs; ++i) {
+                output_ids[out_ids[i]] = i;
+            }
+        }
+
+        // set to total number of outputs in the batch, for use in llama_get_logits_ith
+        lctx.n_outputs = n_outputs;
     }

     // wait for the computation to finish (automatically done when obtaining the model output)
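Note: to summarize what the restored reordering achieves: out_ids[i] records which batch position produced output row i, the selection sort puts rows back into batch order while swapping the matching logits/embedding rows, and output_ids ends up as the inverse permutation that llama_get_logits_ith() consults. A toy illustration (three outputs, no model state):

```cpp
#include <vector>

// Suppose three outputs were produced out of batch order:
std::vector<int> out_ids = { 2, 0, 1 }; // logits row i came from batch position out_ids[i]
std::vector<int> output_ids(3, -1);

for (int i = 0; i < (int) out_ids.size(); ++i) {
    output_ids[out_ids[i]] = i;         // inverse permutation
}
// output_ids == {1, 2, 0}: batch position 0's logits live in row 1, and so on.
// The selection sort above instead physically reorders the rows, after which
// out_ids is {0, 1, 2} and output_ids becomes the identity map.
```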