merged, added ability to render special tokens

This commit is contained in:
Concedo 2024-04-22 18:19:58 +08:00
commit b4d2031215
37 changed files with 335 additions and 7328 deletions

View file

@ -139,7 +139,7 @@ inline bool LogitsDuplicated(std::vector<float> & arr1, std::vector<float> & arr
}
static std::string FileFormatTokenizeID(int id, FileFormat file_format)
static std::string FileFormatTokenizeID(int id, FileFormat file_format, bool return_special = false)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
@ -151,7 +151,7 @@ static std::string FileFormatTokenizeID(int id, FileFormat file_format)
}
else if(file_format == FileFormat::GGUF_GENERIC)
{
return std::string(llama_token_to_str(llama_ctx_v4, id));
return std::string(llama_token_to_piece(llama_ctx_v4, id, return_special));
}
else
{
@ -285,7 +285,7 @@ static std::string get_tok_vec_str(std::vector<int> &embd)
std::string tmp = "";
for (auto id : embd)
{
tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
tmp += "'" + FileFormatTokenizeID(id, file_format, true) + " (" + std::to_string(id) + ")', ";
}
::utreplace(tmp, "\n", "\\n");
return tmp;
@ -604,7 +604,7 @@ static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct
}
GGML_ASSERT(false);
}
const std::string piece = FileFormatTokenizeID(token,file_format); //llama_token_to_str(ctx, token);
const std::string piece = FileFormatTokenizeID(token,file_format);
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
@ -1984,7 +1984,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf("\n[First Run] Banning %zu token sequences...",banned_tokens.size());
for(int v=0;v<n_vocab;++v)
{
std::string word = FileFormatTokenizeID(v,file_format);
std::string word = FileFormatTokenizeID(v,file_format, true);
for(int i=0;i<banned_tokens.size();++i)
{
if (word.find(banned_tokens[i]) != std::string::npos)
@ -2171,7 +2171,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
lowestLogit = LowestLogit(logits);
}
if (!inputs.unban_tokens_rt)
if (!inputs.allow_eos_token)
{
// set the logit of the eos token to very low to avoid sampling it
logitsPtr[eosID] = lowestLogit;
@ -2204,7 +2204,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
for (auto id : embd)
{
std::string tokenizedstr = FileFormatTokenizeID(id, file_format);
std::string tokenizedstr = FileFormatTokenizeID(id, file_format, inputs.render_special);
if(stream_sse)
{
generated_tokens.push_back(tokenizedstr);
@ -2229,14 +2229,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf(" ");
}
firstloop = false;
std::string tokenizedstr = FileFormatTokenizeID(pick.id, file_format);
std::string tokenizedstr = FileFormatTokenizeID(pick.id, file_format, true);
::utreplace(tokenizedstr, "\n", "\\n");
printf("(%s %.2f%%)", RemoveBell(tokenizedstr).c_str(), pick.p*100);
}
printf("]\n");
}
if(inputs.unban_tokens_rt && id==eosID)
if(inputs.allow_eos_token && id==eosID)
{
stopper_unused_tokens = remaining_tokens;
if(allow_regular_prints)