#include "gptj_v1.cpp"

#include <algorithm>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <random>
#include <string>
#include <vector>

// Command-line driver for the legacy (v1) GPT-J model.
//
// Parses gpt_params from argv and loads the model file (default
// "models/gpt-j-6B/ggml-model.bin"). It tokenizes the prompt and feeds it to
// the model in batches of at most params.n_batch tokens. It then samples up to
// params.n_predict new tokens with top-k / top-p / temperature sampling,
// echoing every token to stdout. Timing statistics are printed at the end.
//
// Prompt source, in order of preference:
//   1. params.prompt, as parsed from the command line;
//   2. the lines piped on stdin (when stdin is not a terminal), joined with '\n';
//   3. a random prompt from utils_gpt_random_prompt.
//
// Returns 0 on success. Returns 1 if argument parsing, model loading or any
// model evaluation fails.
int main(int argc, char ** argv) {
    // Token id the loop stops on (the end-of-text token).
    constexpr gpt_vocab::id kEndOfTextToken = 50256;

    ggml_v1_time_init();

    const int64_t t_main_start_us = ggml_v1_time_us();

    gpt_params params;
    params.model = "models/gpt-j-6B/ggml-model.bin";

    if (utils_gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }

    if (params.seed < 0) {
        params.seed = static_cast<int>(time(nullptr));
    }

    printf("%s: seed = %d\n", __func__, params.seed);

    std::mt19937 rng(params.seed);
    if (params.prompt.empty()) {
        if (!isatty(STDIN_FILENO)) {
            // Join piped lines with '\n' *between* them. The old code prepended
            // a newline to every line, so the prompt always began with '\n'.
            std::string line;
            bool first_line = true;
            while (std::getline(std::cin, line)) {
                if (!first_line) {
                    params.prompt += '\n';
                }
                params.prompt += line;
                first_line = false;
            }
        } else {
            params.prompt = utils_gpt_random_prompt(rng);
        }
    }

    int64_t t_load_us = 0;

    gpt_vocab vocab;
    gptj_model_v1 model;
    FileFormat file_format = FileFormat::GPTJ_2;

    // load the model
    {
        const int64_t t_start_us = ggml_v1_time_us();

        if (legacy_gptj_model_load(params.model, model, vocab, file_format) != ModelLoadResult::SUCCESS) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
            return 1;
        }

        t_load_us = ggml_v1_time_us() - t_start_us;
    }

    int n_past = 0;

    int64_t t_sample_us  = 0;
    int64_t t_predict_us = 0;

    std::vector<float> logits;

    // tokenize the prompt
    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);

    // A prompt longer than the context cannot be evaluated. Truncate it
    // explicitly instead of relying on unsigned wrap-around in the loop bound.
    const int n_ctx = model.hparams.n_ctx;
    if (static_cast<int>(embd_inp.size()) > n_ctx) {
        fprintf(stderr, "%s: prompt has %zu tokens but the context holds %d; truncating\n",
                __func__, embd_inp.size(), n_ctx);
        embd_inp.resize(static_cast<size_t>(n_ctx));
    }

    // Never let n_predict go negative. A negative value would wrap around when
    // added to the unsigned prompt size in the loop bound below.
    params.n_predict = std::max<int>(0, std::min<int>(params.n_predict, n_ctx - static_cast<int>(embd_inp.size())));

    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
    printf("\n");

    std::vector<gpt_vocab::id> embd;

    // determine the required inference memory per token:
    size_t mem_per_token = 0;
    if (!legacy_gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format)) {
        fprintf(stderr, "%s: failed to run the warm-up evaluation\n", __func__);
        ggml_v1_free(model.ctx);
        return 1;
    }

    // A batch always holds at least one token, even if n_batch is configured as <= 0.
    const size_t n_batch = static_cast<size_t>(std::max<int>(1, params.n_batch));
    const size_t n_total = embd_inp.size() + static_cast<size_t>(params.n_predict);

    for (size_t i = 0; i < n_total; i++) {
        // predict: evaluate whatever the previous iteration queued in embd
        if (!embd.empty()) {
            const int64_t t_start_us = ggml_v1_time_us();

            if (!legacy_gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, file_format)) {
                printf("Failed to predict\n");
                ggml_v1_free(model.ctx);
                return 1;
            }

            t_predict_us += ggml_v1_time_us() - t_start_us;
        }

        n_past += static_cast<int>(embd.size());
        embd.clear();

        if (i >= embd_inp.size()) {
            // sample next token
            const int   top_k = params.top_k;
            const float top_p = params.top_p;
            const float temp  = params.temp;

            const int n_vocab = model.hparams.n_vocab;

            // The sampler reads the last n_vocab logits. Refuse to form a pointer
            // before the start of the buffer.
            if (logits.size() < static_cast<size_t>(n_vocab)) {
                fprintf(stderr, "%s: expected at least %d logits, got %zu\n", __func__, n_vocab, logits.size());
                ggml_v1_free(model.ctx);
                return 1;
            }

            gpt_vocab::id id = 0;

            {
                const int64_t t_start_sample_us = ggml_v1_time_us();

                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);

                t_sample_us += ggml_v1_time_us() - t_start_sample_us;
            }

            // add it to the context
            embd.push_back(id);
        } else {
            // if here, it means we are still processing the input prompt
            for (size_t k = i; k < embd_inp.size(); k++) {
                embd.push_back(embd_inp[k]);
                // Stop exactly when the batch is full. The former `>` check let
                // every batch grow to n_batch + 1 tokens.
                if (embd.size() >= n_batch) {
                    break;
                }
            }
            // Skip the prompt tokens just queued (the loop's ++ adds the last one).
            i += embd.size() - 1;
        }

        // display text
        for (const auto tok : embd) {
            printf("%s", vocab.id_to_token[tok].c_str());
        }
        fflush(stdout);

        // end of text token (checked on prompt batches as well as sampled tokens)
        if (embd.back() == kEndOfTextToken) {
            break;
        }
    }

    // report timing
    {
        const int64_t t_main_end_us = ggml_v1_time_us();

        // n_past is 0 if the loop ended before any batch was evaluated.
        const float predict_ms_per_token = n_past > 0 ? t_predict_us / 1000.0f / n_past : 0.0f;

        printf("\n\n");
        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
        printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
        printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us / 1000.0f);
        printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f, predict_ms_per_token);
        printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
    }

    ggml_v1_free(model.ctx);

    return 0;
}