Lizonghang 2024-10-23 09:42:32 +04:00
parent 6374743747
commit 2a01ff5fb1
10 changed files with 4725 additions and 1026 deletions


@@ -14,6 +14,7 @@
#include <sstream>
#include <string>
#include <vector>
#include <thread>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
@@ -141,6 +142,9 @@ int main(int argc, char ** argv) {
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_MAIN, print_usage)) {
return 1;
}
const uint32_t n_world = params.n_world;
const uint32_t my_rank = params.rank;
GGML_ASSERT(!(n_world == 1 && my_rank > 0));
gpt_init();
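The hunk above adds the rank bookkeeping that the rest of the commit builds on: `params.n_world` is the number of cooperating nodes, `params.rank` identifies this node, and rank 0 acts as the head node that owns the prompt, the sampler, and console I/O. A minimal, self-contained sketch of that gating pattern (the hard-coded values and the plain `assert` are stand-ins for the `gpt_params` fields and `GGML_ASSERT` used in the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // stand-ins for params.n_world and params.rank parsed from the command line
    const uint32_t n_world = 2;   // total number of cooperating nodes
    const uint32_t my_rank = 0;   // this node's rank; rank 0 is the head node

    // a rank > 0 only makes sense when more than one node participates
    assert(!(n_world == 1 && my_rank > 0));

    if (my_rank == 0) {
        // head node: drives prompt handling, sampling, and console output
        std::printf("rank 0 of %u: driving generation\n", (unsigned) n_world);
    } else {
        // worker node: only takes part in distributed decoding
        std::printf("rank %u of %u: serving compute\n", (unsigned) my_rank, (unsigned) n_world);
    }
    return 0;
}
```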
@@ -151,22 +155,6 @@ int main(int argc, char ** argv) {
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
if (params.logits_all) {
LOG_ERR("************\n");
LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
LOG_ERR("************\n\n");
return 0;
}
if (params.embedding) {
LOG_ERR("************\n");
LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
LOG_ERR("************\n\n");
return 0;
}
if (params.n_ctx != 0 && params.n_ctx < 8) {
LOG_WRN("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
@@ -290,7 +278,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp;
{
if (my_rank == 0) {
auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
: params.prompt;
@@ -304,23 +292,23 @@ int main(int argc, char ** argv) {
LOG_DBG("prompt: \"%s\"\n", prompt.c_str());
LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
}
// Should not run without any tokens
if (embd_inp.empty()) {
if (add_bos) {
embd_inp.push_back(llama_token_bos(model));
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
} else {
LOG_ERR("input is empty\n");
return -1;
// should not run without any tokens
if (embd_inp.empty()) {
if (add_bos) {
embd_inp.push_back(llama_token_bos(model));
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
} else {
LOG_ERR("input is empty\n");
return -1;
}
}
}
// Tokenize negative prompt
if ((int) embd_inp.size() > n_ctx - 4) {
LOG_ERR("%s: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;
// tokenize negative prompt
if ((int) embd_inp.size() > n_ctx - 4) {
LOG_ERR("%s: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;
}
}
// debug message about similarity of saved session, if applicable
@@ -448,18 +436,18 @@ int main(int argc, char ** argv) {
}
}
smpl = gpt_sampler_init(model, sparams);
if (!smpl) {
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
return 1;
if (my_rank == 0) {
smpl = gpt_sampler_init(model, sparams);
if (!smpl) {
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
return 1;
}
LOG_INF("sampler seed: %u\n", gpt_sampler_get_seed(smpl));
LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
LOG_INF("sampler chain: %s\n", gpt_sampler_print(smpl).c_str());
LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
}
LOG_INF("sampler seed: %u\n", gpt_sampler_get_seed(smpl));
LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
LOG_INF("sampler chain: %s\n", gpt_sampler_print(smpl).c_str());
LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
// group-attention state
// number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
int ga_i = 0;
@@ -487,9 +475,7 @@ int main(int argc, char ** argv) {
" - If you want to submit another line, end your input with '\\'.\n";
}
LOG_INF("== Running in interactive mode. ==\n");
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
LOG_INF( " - Press Ctrl+C to interject at any time.\n");
#endif
LOG_INF( " - Enter quit or exit to quit chat.\n");
LOG_INF( "%s\n", control_message);
is_interacting = params.interactive_first;
@@ -525,6 +511,8 @@ int main(int argc, char ** argv) {
}
if (llama_model_has_encoder(model)) {
throw std::runtime_error("this model is currently not supported");
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();
@@ -542,9 +530,16 @@ int main(int argc, char ** argv) {
embd_inp.push_back(decoder_start_token_id);
}
char * stop_signal = nullptr;
std::thread signal_thread;
if (my_rank != 0) {
signal_thread = std::thread(llama_free_sockets, ctx, &stop_signal);
}
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (!embd.empty()) {
if (!embd.empty() || my_rank != 0) {
// Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
// --prompt or --file which uses the same value.
int max_embd_size = n_ctx - 4;
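Two pieces of this hunk work together: non-zero ranks hand `llama_free_sockets(ctx, &stop_signal)` its own thread so a "STOP" notice from the head node can reach them while they sit in the generation loop, and the loop's guard is widened to `!embd.empty() || my_rank != 0` because workers never queue tokens of their own. A rough, self-contained sketch of that listener-thread pattern, with a hypothetical `wait_for_stop` in place of the project's socket call:

```cpp
#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

// hypothetical stand-in for llama_free_sockets(ctx, &stop_signal): block until the
// head node announces shutdown, then publish the signal to the main loop
static void wait_for_stop(std::atomic<bool> & stop) {
    std::this_thread::sleep_for(std::chrono::milliseconds(50)); // pretend to wait on a socket
    stop.store(true);
}

int main() {
    const unsigned my_rank = 1;   // assume this process is a worker node
    std::atomic<bool> stop{false};
    std::thread signal_thread;

    if (my_rank != 0) {
        // workers listen for the stop notice in the background
        signal_thread = std::thread(wait_for_stop, std::ref(stop));
    }

    while (!stop.load()) {
        // a worker's embd stays empty, hence the widened loop condition in the patch
        std::printf("rank %u: participating in the next decode step\n", my_rank);
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }

    if (signal_thread.joinable()) {
        signal_thread.join();
    }
    return 0;
}
```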
@@ -640,25 +635,22 @@ int main(int argc, char ** argv) {
}
}
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
if (my_rank == 0) {
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0)) != 0) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
n_past += n_eval;
}
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
n_past += n_eval;
LOG_DBG("n_past = %d\n", n_past);
// Display total tokens alongside total time
if (params.n_print > 0 && n_past % params.n_print == 0) {
LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
} else {
llama_decode(ctx, llama_batch_get_one(embd.data(), 0, 0, 0));
if (stop_signal != nullptr && std::strcmp(stop_signal, "STOP") == 0) {
break;
}
}
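The decode hunk splits along the same rank boundary: rank 0 keeps the upstream batching logic (walk `embd` in `params.n_batch`-sized chunks, decode each chunk, advance `n_past`), while the other ranks call `llama_decode` with a zero-length batch purely to join the distributed forward pass and then break once `stop_signal` reads "STOP". A minimal sketch of the head node's chunking loop, with a hypothetical `decode_batch` in place of `llama_decode`/`llama_batch_get_one`:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// hypothetical stand-in for llama_decode on one chunk of tokens; returns 0 on success
static int decode_batch(const int * tokens, int n_tokens, int n_past) {
    std::printf("decoding %d tokens at past position %d (first id %d)\n",
                n_tokens, n_past, n_tokens > 0 ? tokens[0] : -1);
    return 0;
}

int main() {
    const std::vector<int> embd = {11, 12, 13, 14, 15, 16, 17}; // pending tokens on the head node
    const int n_batch = 3;                                      // assume params.n_batch == 3
    int n_past = 0;

    // same chunking as the patch: never hand the backend more than n_batch tokens at once
    for (int i = 0; i < (int) embd.size(); i += n_batch) {
        const int n_eval = std::min((int) embd.size() - i, n_batch);
        if (decode_batch(embd.data() + i, n_eval, n_past) != 0) {
            std::fprintf(stderr, "failed to eval\n");
            return 1;
        }
        n_past += n_eval;
    }
    return 0;
}
```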
@@ -670,70 +662,70 @@ int main(int argc, char ** argv) {
embd.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
need_to_save_session = false;
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
if (my_rank == 0) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
need_to_save_session = false;
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
LOG_DBG("saved session to %s\n", path_session.c_str());
}
LOG_DBG("saved session to %s\n", path_session.c_str());
}
const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
gpt_sampler_accept(smpl, id, /* accept_grammar= */ true);
gpt_sampler_accept(smpl, id, /* accept_grammar= */ true);
embd.push_back(id);
// LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());
// echo this to console
input_echo = true;
embd.push_back(id);
// decrement remaining sampling budget
--n_remain;
// echo this to console
input_echo = true;
LOG_DBG("n_remain: %d\n", n_remain);
} else {
// some user input remains from prompt or interaction, forward it to processing
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
// decrement remaining sampling budget
--n_remain;
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
gpt_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false);
LOG_DBG("n_remain: %d\n", n_remain);
} else {
// some user input remains from prompt or interaction, forward it to processing
LOG_DBG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
gpt_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
}
}
}
}
// display text
if (input_echo && display) {
for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
if (my_rank == 0) {
if (input_echo && display) {
for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
// Console/Stream Output
LOG("%s", token_str.c_str());
// Console/Stream Output
LOG("%s", token_str.c_str());
// Record Displayed Tokens To Log
// Note: Generated tokens are created one by one hence this check
if (embd.size() > 1) {
// Incoming Requested Tokens
input_tokens.push_back(id);
} else {
// Outgoing Generated Tokens
output_tokens.push_back(id);
output_ss << token_str;
// Record Displayed Tokens To Log
// Note: Generated tokens are created one by one hence this check
if (embd.size() > 1) {
// Incoming Requested Tokens
input_tokens.push_back(id);
} else {
// Outgoing Generated Tokens
output_tokens.push_back(id);
output_ss << token_str;
}
}
}
}
// reset color to default if there is no pending user input
if (input_echo && (int) embd_inp.size() == n_consumed) {
if (my_rank == 0 && input_echo && (int) embd_inp.size() == n_consumed) {
console::set_display(console::reset);
display = true;
}
@@ -782,7 +774,7 @@ int main(int argc, char ** argv) {
}
// deal with end of generation tokens in interactive mode
if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
if (my_rank == 0 && llama_token_is_eog(model, gpt_sampler_last(smpl))) {
LOG_DBG("found an EOG token\n");
if (params.interactive) {
@@ -840,6 +832,10 @@ int main(int argc, char ** argv) {
console::set_display(console::reset);
display = true;
if (buffer == "quit\n" || buffer == "exit\n") {
break;
}
// Add tokens to embd only if the input buffer is non-empty
// Entering an empty line lets the user pass control back
if (buffer.length() > 1) {
@@ -924,19 +920,20 @@ int main(int argc, char ** argv) {
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
LOG("\n\n");
gpt_perf_print(ctx, smpl);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
gpt_sampler_free(smpl);
if (my_rank == 0) {
LOG("\n\n");
gpt_perf_print(ctx, smpl);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
gpt_sampler_free(smpl);
llama_free_sockets(ctx, &stop_signal);
}
if (my_rank != 0 && signal_thread.joinable()) {
signal_thread.join();
}
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
ggml_threadpool_free(threadpool);
ggml_threadpool_free(threadpool_batch);
return 0;
}
}
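The final hunk makes the shutdown asymmetric: rank 0 logs, frees the sampler, and calls `llama_free_sockets(ctx, &stop_signal)` itself (presumably to tell the workers to stop), while every worker simply joins the listener thread it started before the generation loop. A condensed sketch of that teardown ordering, under the same hypothetical stand-ins as above:

```cpp
#include <cstdio>
#include <thread>

static void notify_workers() { std::printf("head: broadcasting stop\n"); }  // stand-in for llama_free_sockets on rank 0
static void wait_for_stop()  { std::printf("worker: stop received\n");  }   // stand-in for the blocking listener call

int main() {
    const unsigned my_rank = 0;   // flip to a non-zero value to exercise the worker path
    std::thread signal_thread;

    if (my_rank != 0) {
        signal_thread = std::thread(wait_for_stop);  // started before generation, as in the patch
    }

    // ... generation loop would run here ...

    if (my_rank == 0) {
        notify_workers();                  // the head node tears the sockets down explicitly
    } else if (signal_thread.joinable()) {
        signal_thread.join();              // workers wait for that notification to land
    }
    return 0;
}
```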