mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-05 20:29:26 +00:00
Modify the perplexity test to a distributed version
This commit is contained in:
parent
32e1088162
commit
2123879cfe
4 changed files with 314 additions and 149 deletions
|
@ -547,9 +547,9 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
|
|||
// Split utils
|
||||
//
|
||||
|
||||
static const char * const LLM_KV_SPLIT_NO = "split.no";
|
||||
static const char * const LLM_KV_SPLIT_COUNT = "split.count";
|
||||
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
||||
extern const char * const LLM_KV_SPLIT_NO;
|
||||
extern const char * const LLM_KV_SPLIT_COUNT;
|
||||
extern const char * const LLM_KV_SPLIT_TENSORS_COUNT;
|
||||
|
||||
//
|
||||
// YAML utils
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#include "arg.h"
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
@ -473,6 +473,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||
}
|
||||
|
||||
static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
|
||||
uint32_t my_rank = params.rank;
|
||||
uint32_t n_world = params.n_world;
|
||||
bool is_last_dev = (my_rank == n_world - 1);
|
||||
|
||||
if (params.ppl_stride > 0) {
|
||||
return perplexity_v2(ctx, params);
|
||||
}
|
||||
|
@ -485,38 +489,74 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
|
||||
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
|
||||
|
||||
// only last device store logits file
|
||||
std::ofstream logits_stream;
|
||||
if (!params.logits_file.empty()) {
|
||||
logits_stream.open(params.logits_file.c_str(), std::ios::binary);
|
||||
if (!logits_stream.is_open()) {
|
||||
LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
|
||||
return {};
|
||||
if (my_rank == n_world - 1){
|
||||
if (!params.logits_file.empty()) {
|
||||
logits_stream.open(params.logits_file.c_str(), std::ios::binary);
|
||||
if (!logits_stream.is_open()) {
|
||||
LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
|
||||
return {};
|
||||
}
|
||||
LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
|
||||
logits_stream.write("_logits_", 8);
|
||||
logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
|
||||
}
|
||||
LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
|
||||
logits_stream.write("_logits_", 8);
|
||||
logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
|
||||
}
|
||||
|
||||
auto tim1 = std::chrono::high_resolution_clock::now();
|
||||
LOG_INF("%s: tokenizing the input ..\n", __func__);
|
||||
std::vector<llama_token> tokens;
|
||||
size_t tokens_size = 0;
|
||||
|
||||
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
|
||||
// maybe we need to try other solutions, such as direct communication of tokens between the head and tail nodes
|
||||
if (my_rank == 0 || is_last_dev) {
|
||||
auto tim1 = std::chrono::high_resolution_clock::now();
|
||||
LOG_INF("%s: rank %d tokenizing the input ..\n", __func__, my_rank);
|
||||
|
||||
auto tim2 = std::chrono::high_resolution_clock::now();
|
||||
LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
|
||||
tokens = ::llama_tokenize(ctx, params.prompt, true);
|
||||
tokens_size = tokens.size();
|
||||
|
||||
if (int(tokens.size()) < 2*n_ctx) {
|
||||
LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
|
||||
auto tim2 = std::chrono::high_resolution_clock::now();
|
||||
LOG_INF("%s: rank %d tokenization took %g ms\n", __func__, my_rank, 1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
|
||||
}
|
||||
|
||||
{
|
||||
if (n_world > 1) {
|
||||
sync_meta meta;
|
||||
if (my_rank == 0) {
|
||||
meta.tokens_size = tokens_size;
|
||||
llama_send_meta(ctx, &meta);
|
||||
} else {
|
||||
if (llama_recv_meta(ctx, &meta) == -1) {
|
||||
LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
|
||||
return { {}, -1.0, {}, {} };
|
||||
}
|
||||
if (is_last_dev) {
|
||||
GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
|
||||
} else {
|
||||
tokens_size = meta.tokens_size;
|
||||
llama_send_meta(ctx, &meta);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG_INF("%s: rank %d synchronized tokens_size = %zu\n", __func__, my_rank, tokens_size);
|
||||
|
||||
if (my_rank == 0) {
|
||||
if (int(tokens.size()) < 2*n_ctx) {
|
||||
LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
|
||||
n_ctx);
|
||||
LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
|
||||
return {std::move(tokens), 0., {}, {}};
|
||||
LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
|
||||
return {std::move(tokens), 0., {}, {}};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> logit_history;
|
||||
logit_history.resize(tokens.size());
|
||||
|
||||
std::vector<float> prob_history;
|
||||
prob_history.resize(tokens.size());
|
||||
|
||||
if (is_last_dev) {
|
||||
logit_history.resize(tokens_size);
|
||||
prob_history.resize(tokens_size);
|
||||
}
|
||||
|
||||
const int n_chunk_max = tokens.size() / n_ctx;
|
||||
|
||||
|
@ -537,23 +577,34 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
|
||||
|
||||
std::vector<float> logits;
|
||||
if (num_batches > 1) {
|
||||
logits.reserve((size_t)n_ctx * n_vocab);
|
||||
|
||||
if(is_last_dev){
|
||||
if (num_batches > 1) {
|
||||
logits.reserve((size_t)n_ctx * n_vocab);
|
||||
}
|
||||
}
|
||||
|
||||
LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
|
||||
|
||||
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
||||
|
||||
std::vector<uint16_t> log_probs;
|
||||
if (!params.logits_file.empty()) {
|
||||
logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
|
||||
logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
|
||||
logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
|
||||
std::vector<uint16_t> log_probs; // the log probabilities of logits
|
||||
|
||||
// rank 0 and last rank store log_probs
|
||||
// only rank 0 or last device stores logits/log_probs
|
||||
if (!params.logits_file.empty() && (is_last_dev || my_rank == 0)) {
|
||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||
log_probs.resize(n_ctx * nv);
|
||||
}
|
||||
|
||||
// additional operations only for rank 0 or single device
|
||||
if (my_rank == 0) {
|
||||
// For single device, is_last_dev and my_rank==0 are both true
|
||||
// For multiple devices, only rank 0 will write these headers
|
||||
logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
|
||||
logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
|
||||
logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
|
||||
}
|
||||
}
|
||||
// We get the logits for all the tokens in the context window (params.n_ctx)
|
||||
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
|
||||
// calculate the perplexity over the last half of the window (so the model always has
|
||||
|
@ -576,43 +627,98 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
{
|
||||
if (n_world > 1) {
|
||||
sync_meta clear_meta;
|
||||
clear_meta.clear_kv_cache = true;
|
||||
|
||||
if (my_rank == 0) {
|
||||
llama_send_meta(ctx, &clear_meta);
|
||||
} else {
|
||||
if (llama_recv_meta(ctx, &clear_meta) == -1) {
|
||||
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
|
||||
return {tokens, -1.0, {}, {}};
|
||||
}
|
||||
if (!is_last_dev) {
|
||||
llama_send_meta(ctx, &clear_meta);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// clear the KV cache
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
sync_meta meta;
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
||||
int n_outputs = 0;
|
||||
|
||||
int n_outputs = 0;
|
||||
// only rank 0 constructs the batch, other ranks just receive it
|
||||
if (my_rank == 0){
|
||||
|
||||
batch.n_tokens = 0;
|
||||
batch.n_tokens = 0;
|
||||
|
||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||
int seq_start = batch_start + seq*n_ctx;
|
||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||
int seq_start = batch_start + seq*n_ctx;
|
||||
|
||||
// save original token and restore it after eval
|
||||
const auto token_org = tokens[seq_start];
|
||||
// save original token and restore it after eval
|
||||
const auto token_org = tokens[seq_start];
|
||||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
|
||||
}
|
||||
|
||||
for (int k = 0; k < batch_size; ++k) {
|
||||
const int idx = seq*n_ctx + k;
|
||||
|
||||
batch.token [idx] = tokens[seq_start + k];
|
||||
batch.pos [idx] = j*n_batch + k;
|
||||
batch.n_seq_id[idx] = 1;
|
||||
batch.seq_id [idx][0] = seq;
|
||||
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
|
||||
|
||||
n_outputs += batch.logits[idx] != 0;
|
||||
|
||||
}
|
||||
batch.n_tokens += batch_size;
|
||||
|
||||
// restore the original token in case it was set to BOS
|
||||
tokens[seq_start] = token_org;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k = 0; k < batch_size; ++k) {
|
||||
const int idx = seq*n_ctx + k;
|
||||
// other ranks need to know batch info
|
||||
{
|
||||
if (n_world > 1) {
|
||||
meta.n_ctx = n_ctx;
|
||||
|
||||
batch.token [idx] = tokens[seq_start + k];
|
||||
batch.pos [idx] = j*n_batch + k;
|
||||
batch.n_seq_id[idx] = 1;
|
||||
batch.seq_id [idx][0] = seq;
|
||||
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
|
||||
if (my_rank == 0) {
|
||||
// Required batch info: Operation scale, KV cache location, Logits calculation location
|
||||
meta.n_tokens = batch.n_tokens;
|
||||
meta.pos = batch.pos;
|
||||
meta.logits = batch.logits;
|
||||
|
||||
n_outputs += batch.logits[idx] != 0;
|
||||
meta.all_pos_0 = batch.all_pos_0;
|
||||
meta.all_pos_1 = batch.all_pos_1;
|
||||
|
||||
meta.n_outputs = n_outputs;
|
||||
meta.chunk_start_pos = start;
|
||||
|
||||
llama_send_meta(ctx, &meta);
|
||||
} else {
|
||||
if (llama_recv_meta(ctx, &meta) == -1) {
|
||||
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
|
||||
return {tokens, -1.0, {}, {}};
|
||||
}
|
||||
if (!is_last_dev) {
|
||||
llama_send_meta(ctx, &meta);
|
||||
}
|
||||
}
|
||||
}
|
||||
batch.n_tokens += batch_size;
|
||||
|
||||
// restore the original token in case it was set to BOS
|
||||
tokens[seq_start] = token_org;
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
|
@ -620,14 +726,16 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
return {tokens, -1, logit_history, prob_history};
|
||||
}
|
||||
|
||||
if (num_batches > 1 && n_outputs > 0) {
|
||||
const auto * batch_logits = llama_get_logits(ctx);
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
|
||||
if (is_last_dev && num_batches > 1 ) {
|
||||
const int n_outputs_synced = meta.n_outputs;
|
||||
if (n_outputs_synced > 0) {
|
||||
const auto * batch_logits = llama_get_logits(ctx);
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + n_outputs_synced * n_vocab);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (i == 0) {
|
||||
if (my_rank == 0 && i == 0) {
|
||||
llama_synchronize(ctx);
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
|
||||
|
@ -640,53 +748,60 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
LOG("%.2f minutes\n", total_seconds / 60.0);
|
||||
}
|
||||
|
||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
||||
if (is_last_dev) {
|
||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
||||
|
||||
int chunk_start_pos = meta.chunk_start_pos;
|
||||
llama_token * tokens_data = tokens.data() + chunk_start_pos + seq*n_ctx + first;
|
||||
if (!params.logits_file.empty()) {
|
||||
process_logits(logits_stream, n_vocab, all_logits,
|
||||
tokens_data, n_ctx - 1 - first,
|
||||
workers, log_probs, nll, nll2);
|
||||
} else {
|
||||
process_logits(n_vocab, all_logits,
|
||||
tokens_data, n_ctx - 1 - first,
|
||||
workers, nll, nll2,
|
||||
logit_history.data() + chunk_start_pos + seq*n_ctx + first,
|
||||
prob_history.data() + chunk_start_pos + seq*n_ctx + first);
|
||||
}
|
||||
count += n_ctx - first - 1;
|
||||
|
||||
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
|
||||
if (!params.logits_file.empty()) {
|
||||
process_logits(logits_stream, n_vocab, all_logits,
|
||||
tokens_data, n_ctx - 1 - first,
|
||||
workers, log_probs, nll, nll2);
|
||||
} else {
|
||||
process_logits(n_vocab, all_logits,
|
||||
tokens_data, n_ctx - 1 - first,
|
||||
workers, nll, nll2,
|
||||
logit_history.data() + start + seq*n_ctx + first,
|
||||
prob_history.data() + start + seq*n_ctx + first);
|
||||
}
|
||||
count += n_ctx - first - 1;
|
||||
|
||||
// perplexity is e^(average negative log-likelihood)
|
||||
if (params.ppl_output_type == 0) {
|
||||
LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
|
||||
} else {
|
||||
double av = nll/count;
|
||||
double av2 = nll2/count - av*av;
|
||||
if (av2 > 0) av2 = sqrt(av2/(count-1));
|
||||
LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
|
||||
// perplexity is e^(average negative log-likelihood)
|
||||
if (params.ppl_output_type == 0) {
|
||||
LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
|
||||
} else {
|
||||
double av = nll/count;
|
||||
double av2 = nll2/count - av*av;
|
||||
if (av2 > 0) av2 = sqrt(av2/(count-1));
|
||||
LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logits.clear();
|
||||
}
|
||||
LOG("\n");
|
||||
|
||||
nll2 /= count;
|
||||
nll /= count;
|
||||
const double ppl = exp(nll);
|
||||
nll2 -= nll * nll;
|
||||
if (nll2 > 0) {
|
||||
nll2 = sqrt(nll2/(count-1));
|
||||
LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
|
||||
} else {
|
||||
LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
|
||||
if (is_last_dev) {
|
||||
nll2 /= count;
|
||||
nll /= count;
|
||||
const double ppl = exp(nll);
|
||||
nll2 -= nll * nll;
|
||||
|
||||
if (nll2 > 0) {
|
||||
nll2 = sqrt(nll2/(count-1));
|
||||
LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
|
||||
} else {
|
||||
LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
|
||||
}
|
||||
|
||||
return {tokens, ppl, logit_history, prob_history};
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
return {tokens, ppl, logit_history, prob_history};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
|
||||
int prev_outputs = 0;
|
||||
|
@ -1967,6 +2082,23 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
uint32_t n_world = params.n_world;
|
||||
uint32_t my_rank = params.rank;
|
||||
GGML_ASSERT(!(n_world == 1 && my_rank > 0));
|
||||
|
||||
// check if --n-layer-window and --world is matched
|
||||
if (my_rank == 0) {
|
||||
uint32_t non_zero_count = 0;
|
||||
size_t size = sizeof(params.n_layer_window) / sizeof(params.n_layer_window[0]);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
if (params.n_layer_window[i] != 0) {
|
||||
++non_zero_count;
|
||||
}
|
||||
}
|
||||
GGML_ASSERT((non_zero_count == 0 || non_zero_count == n_world) \
|
||||
&& "Number of non-zero values in --n-layer-window must equal --world");
|
||||
}
|
||||
|
||||
gpt_init();
|
||||
|
||||
const int32_t n_ctx = params.n_ctx;
|
||||
|
@ -2008,6 +2140,12 @@ int main(int argc, char ** argv) {
|
|||
// load the model and apply lora adapter, if any
|
||||
llama_init_result llama_init = llama_init_from_gpt_params(params);
|
||||
|
||||
// update rank and world size if any devices removed
|
||||
my_rank = params.rank;
|
||||
n_world = params.n_world;
|
||||
|
||||
bool is_last_dev = (my_rank == n_world - 1);
|
||||
|
||||
llama_model * model = llama_init.model;
|
||||
llama_context * ctx = llama_init.context;
|
||||
if (model == NULL) {
|
||||
|
@ -2028,6 +2166,13 @@ int main(int argc, char ** argv) {
|
|||
LOG_INF("%s\n", gpt_params_get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
char * stop_signal = nullptr;
|
||||
std::thread signal_thread;
|
||||
|
||||
if (my_rank != 0) {
|
||||
signal_thread = std::thread(llama_free_sockets, ctx, &stop_signal);
|
||||
}
|
||||
|
||||
struct results_perplexity results;
|
||||
if (params.hellaswag) {
|
||||
hellaswag_score(ctx, params);
|
||||
|
@ -2042,9 +2187,11 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
LOG("\n");
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
write_logfile(ctx, params, model, results);
|
||||
|
||||
if (is_last_dev) {
|
||||
llama_perf_context_print(ctx);
|
||||
write_logfile(ctx, params, model, results);
|
||||
}
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
|
|
@ -48,6 +48,57 @@
|
|||
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||
#define LLAMA_STATE_SEQ_VERSION 2
|
||||
|
||||
|
||||
typedef int32_t llama_pos;
|
||||
typedef int32_t llama_seq_id;
|
||||
|
||||
struct sync_meta {
|
||||
// batch info
|
||||
int32_t n_tokens = 0;
|
||||
|
||||
int8_t * logits = nullptr;
|
||||
llama_pos * pos = nullptr;
|
||||
llama_pos all_pos_0;
|
||||
llama_pos all_pos_1;
|
||||
uint32_t n_ctx = 0;
|
||||
|
||||
int chunk_start_pos;
|
||||
int32_t n_outputs; // Used to pass the number of logits to be outputted
|
||||
|
||||
// signal to clear the kv cache
|
||||
bool clear_kv_cache= false;
|
||||
|
||||
// signal to remove a kv cache sequence
|
||||
bool kv_seq_rm = false;
|
||||
llama_seq_id rm_seq_id = 0;
|
||||
llama_pos rm_p0 = 0;
|
||||
llama_pos rm_p1 = 0;
|
||||
|
||||
// signal to add a kv cache sequence
|
||||
bool kv_seq_add = false;
|
||||
llama_seq_id add_seq_id = 0;
|
||||
llama_pos add_p0 = 0;
|
||||
llama_pos add_p1 = 0;
|
||||
llama_pos add_delta = 0;
|
||||
|
||||
// signal to copy a kv cache sequence
|
||||
bool kv_seq_cp = false;
|
||||
llama_seq_id cp_src_seq_id = 0;
|
||||
llama_seq_id cp_dst_seq_id = 0;
|
||||
llama_pos cp_p0 = 0;
|
||||
llama_pos cp_p1 = 0;
|
||||
|
||||
// signal to divide the kv cache range
|
||||
bool kv_seq_div = false;
|
||||
llama_seq_id div_seq_id = 0;
|
||||
llama_pos div_p0 = 0;
|
||||
llama_pos div_p1 = 0;
|
||||
int div_factor = 1;
|
||||
|
||||
// signal to transfer tokens_size
|
||||
size_t tokens_size = 0;
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
@ -451,6 +502,8 @@ extern "C" {
|
|||
|
||||
LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
|
||||
LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
|
||||
LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta);
|
||||
LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta);
|
||||
LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
|
||||
LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
|
||||
LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);
|
||||
|
|
|
@ -17866,46 +17866,13 @@ struct input_tensors {
|
|||
ggml_tensor * inp_pos;
|
||||
};
|
||||
|
||||
struct sync_meta {
|
||||
int32_t n_tokens = 0;
|
||||
llama_pos * pos = nullptr;
|
||||
llama_pos all_pos_0;
|
||||
llama_pos all_pos_1;
|
||||
uint32_t n_ctx = 0;
|
||||
|
||||
// signal to clear the kv cache
|
||||
bool clear_kv_cache = false;
|
||||
|
||||
// signal to remove a kv cache sequence
|
||||
bool kv_seq_rm = false;
|
||||
llama_seq_id rm_seq_id = 0;
|
||||
llama_pos rm_p0 = 0;
|
||||
llama_pos rm_p1 = 0;
|
||||
|
||||
// signal to add a kv cache sequence
|
||||
bool kv_seq_add = false;
|
||||
llama_seq_id add_seq_id = 0;
|
||||
llama_pos add_p0 = 0;
|
||||
llama_pos add_p1 = 0;
|
||||
llama_pos add_delta = 0;
|
||||
|
||||
// signal to copy a kv cache sequence
|
||||
bool kv_seq_cp = false;
|
||||
llama_seq_id cp_src_seq_id = 0;
|
||||
llama_seq_id cp_dst_seq_id = 0;
|
||||
llama_pos cp_p0 = 0;
|
||||
llama_pos cp_p1 = 0;
|
||||
|
||||
// signal to divide the kv cache range
|
||||
bool kv_seq_div = false;
|
||||
llama_seq_id div_seq_id = 0;
|
||||
llama_pos div_p0 = 0;
|
||||
llama_pos div_p1 = 0;
|
||||
int div_factor = 1;
|
||||
};
|
||||
|
||||
static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
||||
void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
|
||||
GGML_ASSERT(ctx != nullptr);
|
||||
GGML_ASSERT(meta != nullptr);
|
||||
|
||||
zmq::socket_t * send_socket = ctx->send_socket;
|
||||
GGML_ASSERT(send_socket != nullptr);
|
||||
|
||||
try {
|
||||
std::vector<zmq::message_t> send_msgs;
|
||||
|
||||
|
@ -17924,21 +17891,24 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
|
||||
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
|
||||
|
||||
zmq::send_multipart(socket, send_msgs);
|
||||
if (!send_msgs.empty()) {
|
||||
zmq::send_multipart(*send_socket, send_msgs);
|
||||
}
|
||||
} catch (const zmq::error_t& e) {
|
||||
LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
||||
socket.set(zmq::sockopt::rcvtimeo, 1000);
|
||||
int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
|
||||
ctx->recv_socket->set(zmq::sockopt::rcvtimeo, 1000);
|
||||
|
||||
std::vector<zmq::message_t> recv_msgs;
|
||||
if (!zmq::recv_multipart(socket, std::back_inserter(recv_msgs))) {
|
||||
|
||||
if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
socket.set(zmq::sockopt::rcvtimeo, -1);
|
||||
ctx->recv_socket->set(zmq::sockopt::rcvtimeo, -1);
|
||||
|
||||
const std::string cmd = recv_msgs[0].to_string();
|
||||
size_t idx = 1;
|
||||
|
@ -18210,11 +18180,6 @@ static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool f
|
|||
static int llama_decode_internal(
|
||||
llama_context & lctx,
|
||||
llama_batch & batch_all) { // TODO: rename back to batch
|
||||
|
||||
// llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.get_pos_max() + 1);
|
||||
|
||||
// // const llama_batch & batch_all = batch_allocr.batch;
|
||||
// llama_batch & batch_all = batch_allocr.batch;
|
||||
|
||||
lctx.is_encoding = false;
|
||||
|
||||
|
@ -18277,13 +18242,13 @@ static int llama_decode_internal(
|
|||
n_outputs = 1;
|
||||
}
|
||||
|
||||
// TODO:needs to be encapsulated into a function
|
||||
// prepare for send and receive of metadata
|
||||
sync_meta meta;
|
||||
meta.n_ctx = cparams.n_ctx;
|
||||
bool is_last_dev = (my_rank == n_world - 1);
|
||||
|
||||
if (my_rank != 0) {
|
||||
if (llama_recv_meta(*lctx.recv_socket, &meta) == -1) {
|
||||
if (llama_recv_meta(&lctx, &meta) == -1) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
@ -18343,7 +18308,7 @@ static int llama_decode_internal(
|
|||
meta.pos = batch_all.pos;
|
||||
meta.all_pos_0 = batch_all.all_pos_0;
|
||||
meta.all_pos_1 = batch_all.all_pos_1;
|
||||
llama_send_meta(*lctx.send_socket, &meta);
|
||||
llama_send_meta(&lctx, &meta);
|
||||
}
|
||||
|
||||
lctx.sbatch.from_batch(batch_all, n_embd,
|
||||
|
|
Loading…
Add table
Reference in a new issue