koboldcpp/otherarch/gptj_v1_main.cpp

145 lines
4.2 KiB
C++

#include "gptj_v1.cpp"
int main(int argc, char ** argv) {
ggml_v1_time_init();
const int64_t t_main_start_us = ggml_v1_time_us();
gpt_params params;
params.model = "models/gpt-j-6B/ggml-model.bin";
if (utils_gpt_params_parse(argc, argv, params) == false) {
return 1;
}
if (params.seed < 0) {
params.seed = time(NULL);
}
printf("%s: seed = %d\n", __func__, params.seed);
std::mt19937 rng(params.seed);
if (params.prompt.empty()) {
if( !isatty(STDIN_FILENO) ){
std::string line;
while( std::getline(std::cin, line) ){
params.prompt = params.prompt + "\n" + line;
}
} else {
params.prompt = utils_gpt_random_prompt(rng);
}
}
int64_t t_load_us = 0;
gpt_vocab vocab;
gptj_model_v1 model;
FileFormat file_format = FileFormat::GPTJ2;
// load the model
{
const int64_t t_start_us = ggml_v1_time_us();
if (legacy_gptj_model_load(params.model, model, vocab, file_format)!=ModelLoadResult::SUCCESS) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
t_load_us = ggml_v1_time_us() - t_start_us;
}
int n_past = 0;
int64_t t_sample_us = 0;
int64_t t_predict_us = 0;
std::vector<float> logits;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
printf("\n");
std::vector<gpt_vocab::id> embd;
// determine the required inference memory per token:
size_t mem_per_token = 0;
legacy_gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
// predict
if (embd.size() > 0) {
const int64_t t_start_us = ggml_v1_time_us();
if (!legacy_gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token,file_format)) {
printf("Failed to predict\n");
return 1;
}
t_predict_us += ggml_v1_time_us() - t_start_us;
}
n_past += embd.size();
embd.clear();
if (i >= embd_inp.size()) {
// sample next token
const int top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const int n_vocab = model.hparams.n_vocab;
gpt_vocab::id id = 0;
{
const int64_t t_start_sample_us = ggml_v1_time_us();
id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
t_sample_us += ggml_v1_time_us() - t_start_sample_us;
}
// add it to the context
embd.push_back(id);
} else {
// if here, it means we are still processing the input prompt
for (int k = i; k < embd_inp.size(); k++) {
embd.push_back(embd_inp[k]);
if (embd.size() > params.n_batch) {
break;
}
}
i += embd.size() - 1;
}
// display text
for (auto id : embd) {
printf("%s", vocab.id_to_token[id].c_str());
}
fflush(stdout);
// end of text token
if (embd.back() == 50256) {
break;
}
}
// report timing
{
const int64_t t_main_end_us = ggml_v1_time_us();
printf("\n\n");
printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
}
ggml_v1_free(model.ctx);
return 0;
}