merged, added ability to render special tokens

This commit is contained in:
Concedo 2024-04-22 18:19:58 +08:00
commit b4d2031215
37 changed files with 335 additions and 7328 deletions

107
llama.cpp
View file

@ -1626,12 +1626,12 @@ struct llama_mlock {
};
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
GGML_ASSERT(check == -n_tokens);
}
else {
@ -2146,7 +2146,7 @@ struct llama_vocab {
id special_prefix_id = -1;
id special_suffix_id = -1;
id special_middle_id = -1;
id special_eot_id = -1;
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
bool add_space_prefix = true;
int find_bpe_rank(std::string token_left, std::string token_right) const {
@ -3814,7 +3814,7 @@ static void llm_load_hparams(
switch (hparams.n_layer) {
case 22: model.type = e_model::MODEL_1B; break;
case 26: model.type = e_model::MODEL_3B; break;
case 32: model.type = e_model::MODEL_7B; break;
case 32: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_7B : e_model::MODEL_8B; break; // LLaMa 8B v3 uses GQA
case 40: model.type = e_model::MODEL_13B; break;
case 48: model.type = e_model::MODEL_34B; break;
case 60: model.type = e_model::MODEL_30B; break;
@ -4224,7 +4224,10 @@ static void llm_load_vocab(
vocab.special_prefix_id = 67;
vocab.special_suffix_id = 69;
vocab.special_middle_id = 68;
vocab.special_eot_id = 70;
// TODO: this is not EOT, it is "file separator" token, needs fix
// https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572
//vocab.special_eot_id = 70;
vocab.special_eot_id = 107;
}
}
@ -4371,6 +4374,7 @@ static void llm_load_vocab(
{ LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
{ LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
};
for (const auto & it : special_token_types) {
const std::string & key = kv(std::get<0>(it));
int32_t & id = std::get<1>(it);
@ -4385,7 +4389,6 @@ static void llm_load_vocab(
} else {
id = new_id;
}
}
// Handle add_bos_token and add_eos_token
@ -4399,6 +4402,27 @@ static void llm_load_vocab(
vocab.special_add_eos = int(temp);
}
}
// find EOT token: "<|eot_id|>", "<|im_emd|>", "<end_of_turn>", etc.
//
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
// for now, we apply this workaround to find the EOT token based on its text
if (vocab.special_eot_id == -1) {
for (const auto & t : vocab.token_to_id) {
if (
// TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
// need to fix convert script
//vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
(t.first == "<|eot_id|>" ||
t.first == "<|im_emd|>" ||
t.first == "<end_of_turn>"
)
) {
vocab.special_eot_id = t.second;
break;
}
}
}
}
// build special tokens cache
@ -4561,14 +4585,19 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str());
// special tokens
if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); }
if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); }
if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); }
if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
}
// Returns false if cancelled by progress_callback
@ -13583,16 +13612,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();
bool allow_eos = false;
bool allow_eog = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
allow_eos = true;
allow_eog = true;
break;
}
}
const llama_token eos = llama_token_eos(&ctx->model);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
candidates_decoded.reserve(candidates->size);
std::vector<llama_grammar_candidate> candidates_grammar;
@ -13600,9 +13627,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_str(ctx, id);
if (id == eos) {
if (!allow_eos) {
const std::string piece = llama_token_to_piece(ctx, id, false);
if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
} else if (piece.empty() || piece[0] == 0) {
@ -13791,7 +13819,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us();
if (token == llama_token_eos(&ctx->model)) {
if (llama_token_is_eog(&ctx->model, token)) {
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
return;
@ -13800,7 +13828,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}
const std::string piece = llama_token_to_str(ctx, token);
const std::string piece = llama_token_to_piece(ctx, token, false);
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@ -17196,6 +17224,13 @@ llama_token_type llama_token_get_type(const struct llama_model * model, llama_to
return model->vocab.id_to_token[token].type;
}
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
return token != -1 && (
token == llama_token_eos(model) ||
token == llama_token_eot(model)
);
}
llama_token llama_token_bos(const struct llama_model * model) {
return model->vocab.special_bos_id;
}
@ -17273,12 +17308,11 @@ static std::string llama_decode_text(const std::string & text) {
}
// does not write null-terminator to buf
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) {
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
if(OldBPETokenizerMode)
{
return llama_token_to_piece_old(model, token, buf, length);
}
if (0 <= token && token < llama_n_vocab(model)) {
switch (llama_vocab_get_type(model->vocab)) {
case LLAMA_VOCAB_TYPE_WPM:
@ -17293,7 +17327,9 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_user_defined_token(model->vocab, token)) {
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
@ -17306,8 +17342,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
memcpy(buf, "\xe2\x96\x85", 3);
return 3;
} else if (llama_is_control_token(model->vocab, token)) {
;
} else if (llama_is_byte_token(model->vocab, token)) {
if (length < 1) {
return -1;
@ -17328,15 +17362,15 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_user_defined_token(model->vocab, token)) {
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_control_token(model->vocab, token)) {
;
}
break;
}
@ -17534,6 +17568,15 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
}
} else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) {
// Llama 3
for (auto message : chat) {
std::string role(message->role);
ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
}
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else {
// template not supported
return -1;