#include "gptj_v1.cpp"

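// Minimal example driver for the legacy (ggml v1) GPT-J path: it loads a
// ggml-format GPT-J model, tokenizes a prompt (from the parsed parameters,
// piped stdin, or a random fallback), generates text with top-k / top-p
// sampling, and prints timing statistics at the end.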
int main(int argc, char ** argv) {
    ggml_v1_time_init();
    const int64_t t_main_start_us = ggml_v1_time_us();

    gpt_params params;
    params.model = "models/gpt-j-6B/ggml-model.bin";

    if (utils_gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }

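    // a negative seed (the default) is replaced with the current time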
    if (params.seed < 0) {
        params.seed = time(NULL);
    }

    printf("%s: seed = %d\n", __func__, params.seed);

    std::mt19937 rng(params.seed);

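    // if no prompt was supplied, read one from piped stdin; otherwise fall
    // back to a random prompt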
    if (params.prompt.empty()) {
        if (!isatty(STDIN_FILENO)) {
            std::string line;
            while (std::getline(std::cin, line)) {
                params.prompt = params.prompt + "\n" + line;
            }
        } else {
            params.prompt = utils_gpt_random_prompt(rng);
        }
    }

    int64_t t_load_us = 0;

    gpt_vocab vocab;
    gptj_model_v1 model;
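    // this driver is hard-coded to the legacy GPTJ_2 ggml file format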
    FileFormat file_format = FileFormat::GPTJ_2;

    // load the model
    {
        const int64_t t_start_us = ggml_v1_time_us();

        if (legacy_gptj_model_load(params.model, model, vocab, file_format) != ModelLoadResult::SUCCESS) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
            return 1;
        }

        t_load_us = ggml_v1_time_us() - t_start_us;
    }

    int n_past = 0;

    int64_t t_sample_us  = 0;
    int64_t t_predict_us = 0;

    std::vector<float> logits;

    // tokenize the prompt
    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);

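    // clamp n_predict so that prompt plus generated tokens fit in the context window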
    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());

    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
    printf("\n");

    std::vector<gpt_vocab::id> embd;

    // determine the required inference memory per token:
    size_t mem_per_token = 0;
    legacy_gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);

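    // main loop: first feed the prompt in batches bounded by n_batch, then
    // sample one new token per iteration until n_predict tokens have been
    // generated or the end-of-text token (id 50256) is produced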
    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
        // predict
        if (embd.size() > 0) {
            const int64_t t_start_us = ggml_v1_time_us();

            if (!legacy_gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, file_format)) {
                printf("Failed to predict\n");
                return 1;
            }

            t_predict_us += ggml_v1_time_us() - t_start_us;
        }

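        // n_past counts the tokens the model has already evaluated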
        n_past += embd.size();
        embd.clear();

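        // once the whole prompt has been consumed, sample the next token;
        // otherwise keep feeding prompt tokens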
        if (i >= embd_inp.size()) {
            // sample next token
            const int   top_k = params.top_k;
            const float top_p = params.top_p;
            const float temp  = params.temp;

            const int n_vocab = model.hparams.n_vocab;

            gpt_vocab::id id = 0;

            {
                const int64_t t_start_sample_us = ggml_v1_time_us();

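                // sample from the logits of the last evaluated token using
                // top-k / top-p filtering at the configured temperature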
                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);

                t_sample_us += ggml_v1_time_us() - t_start_sample_us;
            }

            // add it to the context
            embd.push_back(id);
        } else {
            // if here, it means we are still processing the input prompt
            for (int k = i; k < embd_inp.size(); k++) {
                embd.push_back(embd_inp[k]);
                if (embd.size() > params.n_batch) {
                    break;
                }
            }
            i += embd.size() - 1;
        }

        // display text
        for (auto id : embd) {
            printf("%s", vocab.id_to_token[id].c_str());
        }
        fflush(stdout);

        // end of text token
        if (embd.back() == 50256) {
            break;
        }
    }

    // report timing
    {
        const int64_t t_main_end_us = ggml_v1_time_us();

        printf("\n\n");
        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
    }

    ggml_v1_free(model.ctx);

    return 0;
}