mirror of https://github.com/LostRuins/koboldcpp.git
added support for special tokens as stop sequences
commit 3170284fc3
parent b01820dec7
1 changed file with 49 additions and 4 deletions
@@ -102,6 +102,7 @@ static size_t mem_per_token = 0;
 static std::vector<float> logits;
 static std::vector<int> smartcontext;
 static std::vector<std::string> stop_sequence;
+static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
 static std::vector<std::string> banned_tokens;
 static std::vector<int> banned_token_ids;
 static std::vector<llama_token_data> top_picks;
@@ -158,25 +159,40 @@ static std::string FileFormatTokenizeID(int id, FileFormat file_format)
     }
 }
 
-static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format)
+static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format, bool add_bos=true)
 {
     if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_GENERIC)
     {
         if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
         {
-            output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
+            output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, add_bos);
         }
         else if (file_format == FileFormat::GGML)
         {
-            output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
+            output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, add_bos);
         }
         else if (file_format == FileFormat::GGJT_3)
         {
-            output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true);
+            output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, add_bos);
         }
         else
         {
             output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true, true);
+            if(add_bos)
+            {
+                llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model));
+                if(output_tokens.size()==0)
+                {
+                    output_tokens.push_back(bostoadd);
+                }
+                else
+                {
+                    if(output_tokens[0]!=bostoadd)
+                    {
+                        output_tokens.insert(output_tokens.begin(), 1, bostoadd);
+                    }
+                }
+            }
         }
     }
     else
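Note: on the GGUF (v4) path the tokenizer call itself is unchanged, so the new add_bos flag is enforced by hand afterwards: when a BOS token is wanted, one is pushed or prepended only if it is not already there. Below is a minimal standalone sketch of that normalization step; ensure_bos() is a hypothetical helper and token IDs are plain ints rather than the real llama_token plumbing.

    // Hedged sketch of the BOS handling in the hunk above, not the actual
    // koboldcpp code: ensure_bos() is a hypothetical helper name.
    #include <vector>

    static void ensure_bos(std::vector<int> &tokens, int bos_id)
    {
        if (tokens.empty())
        {
            tokens.push_back(bos_id);              // empty input still gets a BOS
        }
        else if (tokens[0] != bos_id)
        {
            tokens.insert(tokens.begin(), bos_id); // prepend without duplicating
        }
    }

Both an empty string and a string the tokenizer already prefixed with BOS end up with exactly one leading BOS, so prompt token counts stay stable.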
@@ -1578,12 +1594,26 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     concat_output_mtx.unlock();
     last_stop_reason = stop_reason::OUT_OF_TOKENS;
     stop_sequence.clear();
+    special_stop_sequence.clear();
     for(int x=0;x<stop_token_max;++x)
     {
         std::string stopper = inputs.stop_sequence[x];
         if(stopper!="")
         {
             stop_sequence.push_back(stopper);
+
+            //if it tokenizes to a single token, AND it's a single non-printable special token, use that
+            std::vector<int> tmp;
+            TokenizeString(stopper, tmp, file_format, false);
+            if(tmp.size()==1) //tokenizes to exactly 1 special token
+            {
+                int specialid = tmp[0];
+                std::string tokenizedstr = FileFormatTokenizeID(specialid, file_format);
+                if(tokenizedstr=="") //must NOT have a text representation
+                {
+                    special_stop_sequence.push_back(specialid);
+                }
+            }
         }
     }
 
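Note: this hunk is where a stop sequence gets classified. Every stopper still goes into the plain-string list, but if it also tokenizes (without a BOS, hence the add_bos=false argument added above) to exactly one token that has no text representation, its ID is remembered separately. A hedged sketch of that rule, with tokenize/detokenize as stand-ins for TokenizeString and FileFormatTokenizeID:

    // Hedged sketch of the classification rule above, not the real API.
    #include <functional>
    #include <string>
    #include <vector>

    // Returns the special token ID, or -1 if the stopper is an ordinary string.
    static int as_special_stop(const std::string &stopper,
                               const std::function<std::vector<int>(const std::string &)> &tokenize,
                               const std::function<std::string(int)> &detokenize)
    {
        std::vector<int> tmp = tokenize(stopper);   // must run WITHOUT add_bos
        if (tmp.size() == 1 && detokenize(tmp[0]) == "")
        {
            return tmp[0]; // single token with no text form: match by ID
        }
        return -1;         // otherwise keep substring matching on the output
    }

Tokenizing without add_bos matters here: a leading BOS would make the vector size 2, so no stopper could ever pass the single-token test.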
@@ -2217,6 +2247,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
             last_stop_reason = stop_reason::EOS_TOKEN_HIT;
         }
 
+        for (const auto &matched : special_stop_sequence)
+        {
+            if(id==matched)
+            {
+                stopper_unused_tokens = remaining_tokens;
+                if(allow_regular_prints)
+                {
+                    printf("\n(Special Stop Token Triggered! ID:%d)",matched);
+                }
+                remaining_tokens = 0;
+                last_stop_reason = stop_reason::EOS_TOKEN_HIT;
+                break;
+            }
+        }
+
         for (const auto &matched : stop_sequence)
         {
             if (concat_output.find(matched) != std::string::npos)
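Note: at generation time the special stoppers are matched by token ID rather than by substring, which is exactly why they work for tokens with no string representation: detokenizing them yields an empty string, so the concat_output substring scan above can never see them. A minimal sketch of the per-token check, with illustrative names rather than the real gpttype_generate locals:

    // Hedged sketch of the per-token ID check added in the hunk above.
    #include <cstdio>
    #include <vector>

    static bool check_special_stop(int id, const std::vector<int> &special_stops)
    {
        for (int matched : special_stops)
        {
            if (id == matched)
            {
                std::printf("\n(Special Stop Token Triggered! ID:%d)", matched);
                return true; // caller zeroes remaining_tokens and records EOS_TOKEN_HIT
            }
        }
        return false;
    }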