fix for DRY segfault on unicode character substring tokenization

Concedo 2024-09-08 18:25:00 +08:00
parent 2e74bd0327
commit c78690737c


@@ -313,6 +313,19 @@ static void print_tok_vec_str(std::vector<int> &vec)
printf("\n%s", get_tok_vec_str(vec).c_str());
}
bool allExtendedUnicode(const std::string& str) {
if(str.size()==0)
{
return false;
}
for (unsigned char c : str) {
if (c <= 127) {
return false;
}
}
return true;
}
// Find tokens that completely contain `str`, either as a single token, or as a sequence of tokens.
// It's important to use a hash map for head tokens because some models have many of them.
// For example, the Llama 3 tokenizer has 6570 tokens containing the period ('.') character.
@@ -322,6 +335,7 @@ static void print_tok_vec_str(std::vector<int> &vec)
// tail tokens are generated by tokenizing the remainder.
// If max_tail_len is >= 0, the maximum token length of a tail sequence is clamped to this value.
static void GetOverlappingTokenSequences(const std::string& str, std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& token_sequences, int max_tail_len = -1) {
    bool isAllExtendedUnicode = allExtendedUnicode(str);
    for(int v=0;v<n_vocab;++v)
    {
        std::string word = FileFormatTokenizeID(v, file_format, true);
@@ -355,7 +369,7 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
                        break;
                    }
                }
                if (match) {
                if (match && !isAllExtendedUnicode) {
                    // We matched to the end of the string. Since `str` is not contained in `word`,
                    // there must be trailing letters in `str`.
                    std::vector<gpt_vocab::id> tokenization;
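
A minimal standalone sketch of the byte-level check this commit introduces, assuming only the C++ standard library (none of the koboldcpp types such as gpt_vocab::id or helpers like FileFormatTokenizeID); the helper mirrors the one in the diff, while main() and its sample strings are illustrative only:

#include <cstdio>
#include <string>

// True only for non-empty strings in which every byte is outside the ASCII
// range, i.e. strings consisting purely of multi-byte UTF-8 sequences.
static bool allExtendedUnicode(const std::string& str) {
    if (str.empty()) {
        return false;
    }
    for (unsigned char c : str) {
        if (c <= 127) {
            return false;
        }
    }
    return true;
}

int main() {
    // "日" encoded as UTF-8 is the three bytes E6 97 A5, all above 127, so a
    // DRY sequence breaker made only of such characters is flagged and the
    // partial-overlap (tail tokenization) path guarded above is skipped.
    std::printf("%d\n", allExtendedUnicode("\xE6\x97\xA5")); // 1
    std::printf("%d\n", allExtendedUnicode("a."));           // 0 (contains ASCII bytes)
    std::printf("%d\n", allExtendedUnicode(""));             // 0 (empty string)
    return 0;
}

The check is byte-based rather than code-point-based: a single ASCII byte anywhere in the string keeps the original tail-tokenization behavior, so ordinary sequence breakers such as "\n" or "." are unaffected by the new guard.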