Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	README.md
#	examples/CMakeLists.txt
#	examples/batched/batched.cpp
#	examples/gritlm/gritlm.cpp
#	examples/llama.android/llama/build.gradle.kts
#	examples/main/README.md
#	examples/retrieval/retrieval.cpp
#	examples/server/CMakeLists.txt
#	examples/server/README.md
#	ggml/CMakeLists.txt
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml.c
#	scripts/compare-commits.sh
#	scripts/sync-ggml.last
#	tests/CMakeLists.txt
#	tests/test-backend-ops.cpp
#	tests/test-chat-template.cpp
#	tests/test-sampling.cpp
This commit is contained in:
Concedo 2024-12-19 11:57:43 +08:00
commit ee486bad3e
59 changed files with 20531 additions and 13185 deletions

View file

@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
// penalties
struct llama_sampler_penalties {
const int32_t n_vocab;
const llama_token special_eos_id;
const llama_token linefeed_id;
const int32_t penalty_last_n;
const float penalty_repeat;
const float penalty_freq;
const float penalty_present;
const bool penalize_nl;
const bool ignore_eos;
ring_buffer<llama_token> prev;
// a frequency map to count token occurrences
std::unordered_map<llama_token, int> token_count;
};
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
return;
}
ctx->token_count[token]++;
// if the ring buffer is full, remove the oldest token
if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
const auto old = ctx->prev.front();
ctx->token_count[old]--;
if (ctx->token_count[old] == 0) {
ctx->token_count.erase(old);
}
}
ctx->prev.push_back(token);
#if 0
// sanity check
std::unordered_map<llama_token, int> tmp;
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
tmp[ctx->prev.rat(i)]++;
}
assert(ctx->token_count == tmp);
#endif
}
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
if (ctx->ignore_eos) {
assert(ctx->special_eos_id >= 0);
// optimistically check if the candidates are not yet sorted/shuffled/truncated
if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
cur_p->data[ctx->special_eos_id].logit = -INFINITY;
} else {
// else, search for the special EOS token
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].id == ctx->special_eos_id) {
cur_p->data[i].logit = -INFINITY;
break;
}
}
}
}
if ((ctx->penalty_last_n == 0) ||
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
return;
}
bool nl_found = false;
size_t nl_idx = 0;
float nl_logit = -INFINITY;
if (!ctx->penalize_nl) {
assert(ctx->linefeed_id >= 0);
// optimistically check if the candidates are not yet sorted/shuffled/truncated
if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
nl_found = true;
nl_idx = ctx->linefeed_id;
nl_logit = cur_p->data[ctx->linefeed_id].logit;
} else {
// else, search for the linefeed token
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].id == ctx->linefeed_id) {
nl_found = true;
nl_idx = i;
nl_logit = cur_p->data[i].logit;
break;
}
}
}
}
// Create a frequency map to count occurrences of each token in last_tokens
// TODO: optimize this by maintaining the token count in the sampler context
using llama_token_cnt = std::unordered_map<llama_token, int>;
llama_token_cnt token_count;
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
token_count[ctx->prev.rat(i)]++;
}
// Apply frequency and presence penalties to the cur_p
for (size_t i = 0; i < cur_p->size; ++i) {
const auto token_iter = token_count.find(cur_p->data[i].id);
if (token_iter == token_count.end()) {
const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
if (token_iter == ctx->token_count.end()) {
continue;
}
const int count = token_iter->second;
assert(count > 0 && count <= ctx->penalty_last_n);
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
if (cur_p->data[i].logit <= 0) {
@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
}
cur_p->sorted = false;
if (!ctx->penalize_nl && nl_found) {
// restore the logit of the newline token if it was penalized
cur_p->data[nl_idx].logit = nl_logit;
}
}
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
ctx->prev.clear();
ctx->token_count.clear();
}
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
auto * result = llama_sampler_init_penalties(
ctx->n_vocab,
ctx->special_eos_id,
ctx->linefeed_id,
ctx->penalty_last_n,
ctx->penalty_repeat,
ctx->penalty_freq,
ctx->penalty_present,
ctx->penalize_nl,
ctx->ignore_eos);
ctx->penalty_present);
// copy the state
{
@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
};
struct llama_sampler * llama_sampler_init_penalties(
int32_t n_vocab,
llama_token special_eos_id,
llama_token linefeed_id,
int32_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present,
bool penalize_nl,
bool ignore_eos) {
if (linefeed_id == LLAMA_TOKEN_NULL) {
penalize_nl = true;
}
if (special_eos_id == LLAMA_TOKEN_NULL) {
ignore_eos = false;
}
float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0);
return new llama_sampler {
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
/* .n_vocab = */ n_vocab,
/* .special_eos_id = */ special_eos_id,
/* .linefeed_id = */ linefeed_id,
/* .penalty_last_n = */ penalty_last_n,
/* .penalty_repeat = */ penalty_repeat,
/* .penalty_freq = */ penalty_freq,
/* .penalty_present = */ penalty_present,
/* .penalize_nl = */ penalize_nl,
/* .ignore_eos = */ ignore_eos,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ {},
},
};
}
@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
if (word.find(str) != std::string::npos) {
token_sequences.emplace(token_id, std::vector<llama_token>());
} else {
size_t word_len = word.size(), str_len = str.size();
size_t word_len = word.size();
size_t str_len = str.size();
size_t pos = -1;
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
bool match = true;

View file

@ -965,7 +965,7 @@ struct llm_tokenizer_wpm_session {
std::vector<std::string> words(1, "");
for (const uint32_t cpt : cpts_nfd) {
const auto flags = unicode_cpt_flags(cpt);
const auto flags = unicode_cpt_flags_from_cpt(cpt);
if (flags.is_whitespace) {
if (words.back().size()) { // finish previous word if any
@ -2127,6 +2127,10 @@ int32_t llama_detokenize_impl(
int32_t text_len_max,
bool remove_special,
bool unparse_special) {
if (vocab.type == LLAMA_VOCAB_TYPE_NONE) {
return 0;
}
GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
int32_t avail = text_len_max;

File diff suppressed because it is too large Load diff

View file

@ -71,15 +71,15 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
throw std::invalid_argument("failed to convert utf8 to codepoint");
}
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
// std::vector<uint16_t> result;
// if (/* 0x0000 <= cp && */ cp <= 0xffff) {
// result.emplace_back(cp);
// if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
// result.emplace_back(cpt);
// return result;
// }
// if (0x10000 <= cp && cp <= 0x10ffff) {
// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
// if (0x10000 <= cpt && cpt <= 0x10ffff) {
// result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
// result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
// return result;
// }
// throw std::invalid_argument("failed to convert codepoint to utf16");
@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
// return result;
//}
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
assert (unicode_ranges_flags.begin()[0].first == 0);
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
@ -261,8 +261,8 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};
size_t _prev_end = offset_ini;
@ -379,8 +379,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};
size_t _prev_end = offset_ini;
@ -580,29 +580,29 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
// interface
//
std::string unicode_cpt_to_utf8(uint32_t cp) {
std::string unicode_cpt_to_utf8(uint32_t cpt) {
std::string result;
if (/* 0x00 <= cp && */ cp <= 0x7f) {
result.push_back(cp);
if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
result.push_back(cpt);
return result;
}
if (0x80 <= cp && cp <= 0x7ff) {
result.push_back(0xc0 | ((cp >> 6) & 0x1f));
result.push_back(0x80 | (cp & 0x3f));
if (0x80 <= cpt && cpt <= 0x7ff) {
result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
result.push_back(0x80 | (cpt & 0x3f));
return result;
}
if (0x800 <= cp && cp <= 0xffff) {
result.push_back(0xe0 | ((cp >> 12) & 0x0f));
result.push_back(0x80 | ((cp >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f));
if (0x800 <= cpt && cpt <= 0xffff) {
result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cpt & 0x3f));
return result;
}
if (0x10000 <= cp && cp <= 0x10ffff) {
result.push_back(0xf0 | ((cp >> 18) & 0x07));
result.push_back(0x80 | ((cp >> 12) & 0x3f));
result.push_back(0x80 | ((cp >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f));
if (0x10000 <= cpt && cpt <= 0x10ffff) {
result.push_back(0xf0 | ((cpt >> 18) & 0x07));
result.push_back(0x80 | ((cpt >> 12) & 0x3f));
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cpt & 0x3f));
return result;
}
@ -632,19 +632,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result;
}
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array();
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
}
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
if (utf8.empty()) {
return undef; // undefined
}
size_t offset = 0;
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
}
std::string unicode_byte_to_utf8(uint8_t byte) {
@ -657,41 +657,41 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
return map.at(utf8);
}
uint32_t unicode_tolower(uint32_t cp) {
uint32_t unicode_tolower(uint32_t cpt) {
// binary search
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cp,
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
[](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
return pair.first < value;
});
if (it != unicode_map_lowercase.end() && it->first == cp) {
if (it != unicode_map_lowercase.end() && it->first == cpt) {
return it->second;
}
return cp; // Return the original code point if no lowercase mapping is found
return cpt; // Return the original code point if no lowercase mapping is found
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_flags::NUMBER },
{ "\\p{L}", codepoint_flags::LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION },
{ "\\p{N}", unicode_cpt_flags::NUMBER },
{ "\\p{L}", unicode_cpt_flags::LETTER },
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
};
static const std::map<int, int> k_ucat_cpt = {
{ codepoint_flags::NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 },
{ unicode_cpt_flags::NUMBER, 0xD1 },
{ unicode_cpt_flags::LETTER, 0xD2 },
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
};
static const std::map<int, std::string> k_ucat_map = {
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
};
// compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false;
for (auto & regex_expr : regex_exprs) {
for (const auto & regex_expr : regex_exprs) {
// search for unicode categories
for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
@ -717,7 +717,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue;
}
const auto flags = unicode_cpt_flags(cpts[i]);
const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
if (flags.is_whitespace) {
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
@ -733,7 +733,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
std::vector<size_t> bpe_offsets = { cpts.size() };
for (auto & regex_expr : regex_exprs) {
for (const auto & regex_expr : regex_exprs) {
// first, see if we have an efficient custom regex implementation
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
@ -747,7 +747,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed representation
bool use_collapsed = false;
for (auto & ucat : k_ucat_enum) {
for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
use_collapsed = true;
break;
@ -813,7 +813,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
wtext[i] = 0x0B;
}
}

View file

@ -4,9 +4,7 @@
#include <string>
#include <vector>
// TODO: prefix all symbols with "llama_"
struct codepoint_flags {
struct unicode_cpt_flags {
enum {
UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N}
@ -35,7 +33,7 @@ struct codepoint_flags {
uint16_t is_nfd : 1;
// decode from uint16
inline codepoint_flags(const uint16_t flags=0) {
inline unicode_cpt_flags(const uint16_t flags = 0) {
*reinterpret_cast<uint16_t*>(this) = flags;
}
@ -50,18 +48,19 @@ struct codepoint_flags {
size_t unicode_len_utf8(char src);
std::string unicode_cpt_to_utf8(uint32_t cp);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::string unicode_cpt_to_utf8 (uint32_t cpt);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
codepoint_flags unicode_cpt_flags(const uint32_t cp);
codepoint_flags unicode_cpt_flags(const std::string & utf8);
unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint32_t unicode_tolower(uint32_t cp);
uint32_t unicode_tolower(uint32_t cpt);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);