diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index b1e374981..ea9ecca3f 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -102,6 +102,7 @@ static size_t mem_per_token = 0;
 static std::vector<float> logits;
 static std::vector<int> smartcontext;
 static std::vector<std::string> stop_sequence;
+static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
 static std::vector<std::string> banned_tokens;
 static std::vector<int> banned_token_ids;
 static std::vector<llama_token_data> top_picks;
@@ -158,25 +159,40 @@ static std::string FileFormatTokenizeID(int id, FileFormat file_format)
     }
 }
 
-static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format)
+static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format, bool add_bos=true)
 {
     if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_GENERIC)
     {
         if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
         {
-            output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
+            output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, add_bos);
         }
         else if (file_format == FileFormat::GGML)
         {
-            output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
+            output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, add_bos);
         }
         else if (file_format == FileFormat::GGJT_3)
         {
-            output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true);
+            output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, add_bos);
         }
         else
         {
             output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true, true);
+            if(add_bos)
+            {
+                llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model));
+                if(output_tokens.size()==0)
+                {
+                    output_tokens.push_back(bostoadd);
+                }
+                else
+                {
+                    if(output_tokens[0]!=bostoadd)
+                    {
+                        output_tokens.insert(output_tokens.begin(), 1, bostoadd);
+                    }
+                }
+            }
         }
     }
     else
@@ -1578,12 +1594,26 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     concat_output_mtx.unlock();
     last_stop_reason = stop_reason::OUT_OF_TOKENS;
     stop_sequence.clear();
+    special_stop_sequence.clear();
     for(int x=0;x<stop_token_max;x++)
     {
         std::string stopper = inputs.stop_sequence[x];
         if(stopper!="")
         {
             stop_sequence.push_back(stopper);
+
+            //if it tokenizes to a single token, AND it's a special token, use that
+            std::vector<int> tmp;
+            TokenizeString(stopper, tmp, file_format, false);
+            if(tmp.size()==1) //tokenizes to exactly 1 special token
+            {
+                int specialid = tmp[0];
+                std::string tokenizedstr = FileFormatTokenizeID(specialid, file_format);
+                if(tokenizedstr=="") //must NOT have a text representation
+                {
+                    special_stop_sequence.push_back(specialid);
+                }
+            }
         }
     }
 
@@ -2217,6 +2247,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                 last_stop_reason = stop_reason::EOS_TOKEN_HIT;
             }
 
+            for (const auto &matched : special_stop_sequence)
+            {
+                if(id==matched)
+                {
+                    stopper_unused_tokens = remaining_tokens;
+                    if(allow_regular_prints)
+                    {
+                        printf("\n(Special Stop Token Triggered! ID:%d)",matched);
+                    }
+                    remaining_tokens = 0;
+                    last_stop_reason = stop_reason::EOS_TOKEN_HIT;
+                    break;
+                }
+            }
+
             for (const auto &matched : stop_sequence)
             {
                 if (concat_output.find(matched) != std::string::npos)
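
Note on the TokenizeString change: for the GGUF path, add_bos is now enforced manually, so the result starts with exactly one BOS token whether the tokenizer returned nothing, returned tokens without a leading BOS, or already included one. A minimal standalone sketch of that normalization (llama_token is reduced to a plain int here, and ensure_leading_bos is a hypothetical name, not part of this patch):

    #include <cassert>
    #include <vector>

    typedef int llama_token; // stand-in; the real id comes from llama_token_bos()

    // Ensure a token list begins with exactly one BOS: append to an empty
    // list, prepend when the first token is something else, and leave the
    // list alone when BOS is already in front.
    static void ensure_leading_bos(std::vector<llama_token> &tokens, llama_token bos)
    {
        if (tokens.empty())
        {
            tokens.push_back(bos);
        }
        else if (tokens[0] != bos)
        {
            tokens.insert(tokens.begin(), 1, bos);
        }
    }

    int main()
    {
        std::vector<llama_token> a;          // empty       -> {1}
        std::vector<llama_token> b = {5, 6}; // no BOS      -> {1, 5, 6}
        std::vector<llama_token> c = {1, 5}; // already BOS -> unchanged
        ensure_leading_bos(a, 1);
        ensure_leading_bos(b, 1);
        ensure_leading_bos(c, 1);
        assert(a.size() == 1 && a[0] == 1);
        assert(b.size() == 3 && b[0] == 1);
        assert(c.size() == 2 && c[0] == 1);
        return 0;
    }

This keeps add_bos idempotent: tokenizing a prompt that already begins with BOS never yields a doubled BOS token.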
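
The special-stop handling works in two stages. At request setup, each stop string is tokenized without BOS; if it maps to exactly one token and FileFormatTokenizeID returns an empty string for that id (i.e. the token has no text representation, as with special tokens such as <|im_end|>), the id is stored in special_stop_sequence. During sampling, a sampled id matching any stored id ends generation immediately, which catches stop tokens that never appear in concat_output and so could not be matched as strings; the string-based scan over stop_sequence still runs afterwards, so ordinary text stoppers behave as before. A self-contained toy sketch of the detection heuristic (the vocabulary and token ids below are made up for illustration):

    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    // Toy stand-ins for the real tokenizer. Special tokens map to one id
    // and have no text representation; everything else is one id per byte.
    static const std::map<std::string, int> toy_specials = {
        {"</s>", 2}, {"<|im_end|>", 32000}}; // illustrative ids only

    static std::vector<int> toy_tokenize(const std::string &s)
    {
        auto it = toy_specials.find(s);
        if (it != toy_specials.end()) return {it->second};
        std::vector<int> out;
        for (unsigned char c : s) out.push_back(1000 + c);
        return out;
    }

    static std::string toy_detokenize_one(int id) // "" marks a special token
    {
        for (const auto &kv : toy_specials)
            if (kv.second == id) return "";
        return std::string(1, (char)(id - 1000));
    }

    int main()
    {
        std::vector<int> special_stop_ids;
        for (const std::string stopper : {"</s>", "###", "<|im_end|>"})
        {
            // Mirror of the patch: exactly one token AND no text form.
            std::vector<int> tmp = toy_tokenize(stopper);
            if (tmp.size() == 1 && toy_detokenize_one(tmp[0]) == "")
                special_stop_ids.push_back(tmp[0]);
        }
        for (int id : special_stop_ids)
            printf("special stop id: %d\n", id); // prints 2 and 32000
        return 0;
    }

Here "###" tokenizes to several ids and is left to the existing string matcher, while "</s>" and "<|im_end|>" are promoted to id-based stops.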