added a token counting endpoint, set mmq as default

This commit is contained in:
Concedo 2023-08-24 20:41:49 +08:00
parent 81a0ef342c
commit b95a4ccb22
5 changed files with 72 additions and 28 deletions

View file

@ -338,6 +338,36 @@ static std::string FileFormatTokenizeID(int id, FileFormat file_format)
}
}
static std::vector<int> TokenizeString(const std::string & str_to_tokenize, FileFormat file_format)
{
std::vector<int> tokvec;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA)
{
if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
{
tokvec = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGML)
{
tokvec = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGJT_3)
{
tokvec = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true);
}
else
{
tokvec = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true);
}
}
else
{
// tokenize the prompt
tokvec = ::gpt_tokenize(vocab, str_to_tokenize);
}
return tokvec;
}
static std::string RemoveBell(const std::string & input) //removes the bell character
{
std::string word2;
@ -965,6 +995,21 @@ bool gpttype_generate_abort()
return true;
}
int gpttype_token_count(const std::string & input)
{
if(debugmode==1)
{
printf("\nFileFormat: %d, Tokenizing: %s",file_format ,input.c_str());
}
auto toks = TokenizeString(input, file_format);
int tokcount = toks.size();
if(debugmode==1)
{
printf("\nTokens Counted: %d\n",tokcount);
}
return tokcount;
}
const std::string & gpttype_get_pending_output()
{
return concat_output;
@ -1018,32 +1063,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
// tokenize the prompt
std::vector<int> embd_inp;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA)
{
if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
{
embd_inp = ::llama_v2_tokenize(llama_ctx_v2, params.prompt, true);
}
else if (file_format == FileFormat::GGML)
{
embd_inp = ::legacy_llama_v2_tokenize(llama_ctx_v2, params.prompt, true);
}
else if (file_format == FileFormat::GGJT_3)
{
embd_inp = ::llama_v3_tokenize(llama_ctx_v3, params.prompt, true);
}
else
{
embd_inp = ::llama_tokenize(llama_ctx_v4, params.prompt, true);
}
}
else
{
// tokenize the prompt
embd_inp = ::gpt_tokenize(vocab, params.prompt);
}
std::vector<int> embd_inp = TokenizeString(params.prompt, file_format);
//truncate to front of the prompt if its too long
int32_t nctx = params.n_ctx;