mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-09 17:44:34 +00:00
init
This commit is contained in:
parent
6374743747
commit
2a01ff5fb1
10 changed files with 4725 additions and 1026 deletions
|
@ -14,6 +14,7 @@
|
|||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
|
@ -141,6 +142,9 @@ int main(int argc, char ** argv) {
|
|||
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_MAIN, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
const uint32_t n_world = params.n_world;
|
||||
const uint32_t my_rank = params.rank;
|
||||
GGML_ASSERT(!(n_world == 1 && my_rank > 0));
|
||||
|
||||
gpt_init();
|
||||
|
||||
|
@ -151,22 +155,6 @@ int main(int argc, char ** argv) {
|
|||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
if (params.logits_all) {
|
||||
LOG_ERR("************\n");
|
||||
LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
|
||||
LOG_ERR("************\n\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (params.embedding) {
|
||||
LOG_ERR("************\n");
|
||||
LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
|
||||
LOG_ERR("************\n\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (params.n_ctx != 0 && params.n_ctx < 8) {
|
||||
LOG_WRN("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
|
||||
params.n_ctx = 8;
|
||||
|
@ -290,7 +278,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<llama_token> embd_inp;
|
||||
|
||||
{
|
||||
if (my_rank == 0) {
|
||||
auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
|
||||
? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
|
||||
: params.prompt;
|
||||
|
@ -304,23 +292,23 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG_DBG("prompt: \"%s\"\n", prompt.c_str());
|
||||
LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
|
||||
}
|
||||
|
||||
// Should not run without any tokens
|
||||
if (embd_inp.empty()) {
|
||||
if (add_bos) {
|
||||
embd_inp.push_back(llama_token_bos(model));
|
||||
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
|
||||
} else {
|
||||
LOG_ERR("input is empty\n");
|
||||
return -1;
|
||||
// should not run without any tokens
|
||||
if (embd_inp.empty()) {
|
||||
if (add_bos) {
|
||||
embd_inp.push_back(llama_token_bos(model));
|
||||
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
|
||||
} else {
|
||||
LOG_ERR("input is empty\n");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tokenize negative prompt
|
||||
if ((int) embd_inp.size() > n_ctx - 4) {
|
||||
LOG_ERR("%s: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
|
||||
return 1;
|
||||
// tokenize negative prompt
|
||||
if ((int) embd_inp.size() > n_ctx - 4) {
|
||||
LOG_ERR("%s: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// debug message about similarity of saved session, if applicable
|
||||
|
@ -448,18 +436,18 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
smpl = gpt_sampler_init(model, sparams);
|
||||
if (!smpl) {
|
||||
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
|
||||
return 1;
|
||||
if (my_rank == 0) {
|
||||
smpl = gpt_sampler_init(model, sparams);
|
||||
if (!smpl) {
|
||||
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
LOG_INF("sampler seed: %u\n", gpt_sampler_get_seed(smpl));
|
||||
LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
|
||||
LOG_INF("sampler chain: %s\n", gpt_sampler_print(smpl).c_str());
|
||||
LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||
}
|
||||
|
||||
LOG_INF("sampler seed: %u\n", gpt_sampler_get_seed(smpl));
|
||||
LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
|
||||
LOG_INF("sampler chain: %s\n", gpt_sampler_print(smpl).c_str());
|
||||
|
||||
LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||
|
||||
// group-attention state
|
||||
// number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
|
||||
int ga_i = 0;
|
||||
|
@ -487,9 +475,7 @@ int main(int argc, char ** argv) {
|
|||
" - If you want to submit another line, end your input with '\\'.\n";
|
||||
}
|
||||
LOG_INF("== Running in interactive mode. ==\n");
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
LOG_INF( " - Press Ctrl+C to interject at any time.\n");
|
||||
#endif
|
||||
LOG_INF( " - Enter quit or exit to quit chat.\n");
|
||||
LOG_INF( "%s\n", control_message);
|
||||
|
||||
is_interacting = params.interactive_first;
|
||||
|
@ -525,6 +511,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
throw std::runtime_error("this model is currently not supported");
|
||||
|
||||
int enc_input_size = embd_inp.size();
|
||||
llama_token * enc_input_buf = embd_inp.data();
|
||||
|
||||
|
@ -542,9 +530,16 @@ int main(int argc, char ** argv) {
|
|||
embd_inp.push_back(decoder_start_token_id);
|
||||
}
|
||||
|
||||
char * stop_signal = nullptr;
|
||||
std::thread signal_thread;
|
||||
|
||||
if (my_rank != 0) {
|
||||
signal_thread = std::thread(llama_free_sockets, ctx, &stop_signal);
|
||||
}
|
||||
|
||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||
// predict
|
||||
if (!embd.empty()) {
|
||||
if (!embd.empty() || my_rank != 0) {
|
||||
// Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
|
||||
// --prompt or --file which uses the same value.
|
||||
int max_embd_size = n_ctx - 4;
|
||||
|
@ -640,25 +635,22 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
|
||||
int n_eval = (int) embd.size() - i;
|
||||
if (n_eval > params.n_batch) {
|
||||
n_eval = params.n_batch;
|
||||
if (my_rank == 0) {
|
||||
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
|
||||
int n_eval = (int) embd.size() - i;
|
||||
if (n_eval > params.n_batch) {
|
||||
n_eval = params.n_batch;
|
||||
}
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0)) != 0) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
n_past += n_eval;
|
||||
}
|
||||
|
||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
n_past += n_eval;
|
||||
|
||||
LOG_DBG("n_past = %d\n", n_past);
|
||||
// Display total tokens alongside total time
|
||||
if (params.n_print > 0 && n_past % params.n_print == 0) {
|
||||
LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
|
||||
} else {
|
||||
llama_decode(ctx, llama_batch_get_one(embd.data(), 0, 0, 0));
|
||||
if (stop_signal != nullptr && std::strcmp(stop_signal, "STOP") == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -670,70 +662,70 @@ int main(int argc, char ** argv) {
|
|||
|
||||
embd.clear();
|
||||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
// optionally save the session on first sample (for faster prompt loading next time)
|
||||
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
||||
need_to_save_session = false;
|
||||
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||
if (my_rank == 0) {
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
// optionally save the session on first sample (for faster prompt loading next time)
|
||||
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
||||
need_to_save_session = false;
|
||||
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||
LOG_DBG("saved session to %s\n", path_session.c_str());
|
||||
}
|
||||
|
||||
LOG_DBG("saved session to %s\n", path_session.c_str());
|
||||
}
|
||||
const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
|
||||
|
||||
const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
|
||||
gpt_sampler_accept(smpl, id, /* accept_grammar= */ true);
|
||||
|
||||
gpt_sampler_accept(smpl, id, /* accept_grammar= */ true);
|
||||
embd.push_back(id);
|
||||
|
||||
// LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());
|
||||
// echo this to console
|
||||
input_echo = true;
|
||||
|
||||
embd.push_back(id);
|
||||
// decrement remaining sampling budget
|
||||
--n_remain;
|
||||
|
||||
// echo this to console
|
||||
input_echo = true;
|
||||
LOG_DBG("n_remain: %d\n", n_remain);
|
||||
} else {
|
||||
// some user input remains from prompt or interaction, forward it to processing
|
||||
while ((int) embd_inp.size() > n_consumed) {
|
||||
embd.push_back(embd_inp[n_consumed]);
|
||||
|
||||
// decrement remaining sampling budget
|
||||
--n_remain;
|
||||
// push the prompt in the sampling context in order to apply repetition penalties later
|
||||
// for the prompt, we don't apply grammar rules
|
||||
gpt_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false);
|
||||
|
||||
LOG_DBG("n_remain: %d\n", n_remain);
|
||||
} else {
|
||||
// some user input remains from prompt or interaction, forward it to processing
|
||||
LOG_DBG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
|
||||
while ((int) embd_inp.size() > n_consumed) {
|
||||
embd.push_back(embd_inp[n_consumed]);
|
||||
|
||||
// push the prompt in the sampling context in order to apply repetition penalties later
|
||||
// for the prompt, we don't apply grammar rules
|
||||
gpt_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false);
|
||||
|
||||
++n_consumed;
|
||||
if ((int) embd.size() >= params.n_batch) {
|
||||
break;
|
||||
++n_consumed;
|
||||
if ((int) embd.size() >= params.n_batch) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// display text
|
||||
if (input_echo && display) {
|
||||
for (auto id : embd) {
|
||||
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
|
||||
if (my_rank == 0) {
|
||||
if (input_echo && display) {
|
||||
for (auto id : embd) {
|
||||
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
|
||||
|
||||
// Console/Stream Output
|
||||
LOG("%s", token_str.c_str());
|
||||
// Console/Stream Output
|
||||
LOG("%s", token_str.c_str());
|
||||
|
||||
// Record Displayed Tokens To Log
|
||||
// Note: Generated tokens are created one by one hence this check
|
||||
if (embd.size() > 1) {
|
||||
// Incoming Requested Tokens
|
||||
input_tokens.push_back(id);
|
||||
} else {
|
||||
// Outgoing Generated Tokens
|
||||
output_tokens.push_back(id);
|
||||
output_ss << token_str;
|
||||
// Record Displayed Tokens To Log
|
||||
// Note: Generated tokens are created one by one hence this check
|
||||
if (embd.size() > 1) {
|
||||
// Incoming Requested Tokens
|
||||
input_tokens.push_back(id);
|
||||
} else {
|
||||
// Outgoing Generated Tokens
|
||||
output_tokens.push_back(id);
|
||||
output_ss << token_str;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reset color to default if there is no pending user input
|
||||
if (input_echo && (int) embd_inp.size() == n_consumed) {
|
||||
if (my_rank == 0 && input_echo && (int) embd_inp.size() == n_consumed) {
|
||||
console::set_display(console::reset);
|
||||
display = true;
|
||||
}
|
||||
|
@ -782,7 +774,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// deal with end of generation tokens in interactive mode
|
||||
if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
|
||||
if (my_rank == 0 && llama_token_is_eog(model, gpt_sampler_last(smpl))) {
|
||||
LOG_DBG("found an EOG token\n");
|
||||
|
||||
if (params.interactive) {
|
||||
|
@ -840,6 +832,10 @@ int main(int argc, char ** argv) {
|
|||
console::set_display(console::reset);
|
||||
display = true;
|
||||
|
||||
if (buffer == "quit\n" || buffer == "exit\n") {
|
||||
break;
|
||||
}
|
||||
|
||||
// Add tokens to embd only if the input buffer is non-empty
|
||||
// Entering a empty line lets the user pass control back
|
||||
if (buffer.length() > 1) {
|
||||
|
@ -924,19 +920,20 @@ int main(int argc, char ** argv) {
|
|||
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||
}
|
||||
|
||||
LOG("\n\n");
|
||||
gpt_perf_print(ctx, smpl);
|
||||
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
||||
|
||||
gpt_sampler_free(smpl);
|
||||
|
||||
if (my_rank == 0) {
|
||||
LOG("\n\n");
|
||||
gpt_perf_print(ctx, smpl);
|
||||
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
||||
gpt_sampler_free(smpl);
|
||||
llama_free_sockets(ctx, &stop_signal);
|
||||
}
|
||||
if (my_rank != 0 && signal_thread.joinable()) {
|
||||
signal_thread.join();
|
||||
}
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
ggml_threadpool_free(threadpool);
|
||||
ggml_threadpool_free(threadpool_batch);
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue