fix to allow all EOGs to trigger a stop, occam's glm4 fix,

This commit is contained in:
Concedo 2025-05-24 22:55:11 +08:00
parent bd7a40f326
commit f97bbdde00
6 changed files with 54 additions and 22 deletions

View file

@ -301,14 +301,29 @@ static int GetEosID(FileFormat file_format, int32_t n_vocab)
} }
return eosID; return eosID;
} }
static int GetEotID(FileFormat file_format)
static std::vector<int> GetEogIDs(FileFormat file_format, int32_t n_vocab)
{ {
std::vector<int> alleogs;
int eos = GetEosID(file_format, n_vocab);
if(file_format == FileFormat::GGUF_GENERIC) if(file_format == FileFormat::GGUF_GENERIC)
{ {
const llama_vocab * tmpvocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4)); const llama_vocab * tmpvocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
return llama_vocab_eot(tmpvocab); int eot = llama_vocab_eot(tmpvocab);
std::set<int> eogs = tmpvocab->get_eogs();
if (eot >= 0) {
eogs.insert(eot);
}
if (eos >= 0) {
eogs.insert(eos);
}
alleogs = std::vector<int>(eogs.begin(), eogs.end());
} else {
if (eos >= 0) {
alleogs.push_back(eos);
}
} }
return -1; return alleogs;
} }
static float LowestLogit(const std::vector<float> & logits) static float LowestLogit(const std::vector<float> & logits)
@ -1550,8 +1565,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
} }
} }
const llama_token eos = GetEosID(file_format,n_vocab); const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
const llama_token eot = GetEotID(file_format);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar; std::vector<llama_grammar_candidate> candidates_grammar;
@ -1559,7 +1573,8 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id; const llama_token id = candidates->data[i].id;
const std::string piece = FileFormatTokenizeID(id,file_format); const std::string piece = FileFormatTokenizeID(id,file_format);
if (id == eos || (id==eot && id!=-1)) { bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
if (found_eog) {
if (!allow_eos) { if (!allow_eos) {
candidates->data[i].logit = -INFINITY; candidates->data[i].logit = -INFINITY;
} }
@ -1711,7 +1726,9 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token) static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
{ {
if (token == GetEosID(file_format,n_vocab) || (token!=-1 && token == GetEotID(file_format))) { const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), token) != eog_tokens.end();
if (found_eog) {
for (const auto & stack : grammar->stacks) { for (const auto & stack : grammar->stacks) {
if (stack.empty()) { if (stack.empty()) {
return; return;
@ -3827,8 +3844,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
} }
} }
unsigned int eosID = GetEosID(file_format, n_vocab); const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
unsigned int eotID = GetEotID(file_format);
float * logitsPtr; float * logitsPtr;
float lowestLogit = 0; float lowestLogit = 0;
int btsize = banned_token_ids.size(); int btsize = banned_token_ids.size();
@ -3886,13 +3902,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if (!inputs.allow_eos_token && !inputs.bypass_eos_token) if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
{ {
// set the logit of the eos token to very low to avoid sampling it // set the logit of the eos token to very low to avoid sampling it
if(eosID!=LLAMA_TOKEN_NULL) for(int i=0;i<eog_tokens.size();++i)
{ {
logitsPtr[eosID] = lowestLogit; logitsPtr[eog_tokens[i]] = lowestLogit;
}
if(eotID!=-1)
{
logitsPtr[eotID] = lowestLogit;
} }
} }
if(btsize>0) if(btsize>0)
@ -3958,7 +3970,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
for (auto eid : embd) for (auto eid : embd)
{ {
std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, inputs.render_special); std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, inputs.render_special);
if(!inputs.render_special && (eid==eosID || (eid==eotID && eid!=-1) || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), eid) != eog_tokens.end();
if(!inputs.render_special && (found_eog || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
{ {
tokenizedstr = ""; //prevent render tokenizedstr = ""; //prevent render
} }
@ -4059,7 +4072,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if(!early_abort) if(!early_abort)
{ {
if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1))) bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
if(!inputs.bypass_eos_token && inputs.allow_eos_token && found_eog)
{ {
if(allow_regular_prints) if(allow_regular_prints)
{ {

View file

@ -12,6 +12,7 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <stdbool.h> #include <stdbool.h>
#include <set>
#ifdef LLAMA_SHARED #ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__) # if defined(_WIN32) && !defined(__MINGW32__)
@ -941,6 +942,8 @@ extern "C" {
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
LLAMA_API std::set<int> llama_vocab_get_eogs(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);

View file

@ -52,7 +52,7 @@ logit_bias_max = 512
dry_seq_break_max = 128 dry_seq_break_max = 128
# global vars # global vars
KcppVersion = "1.92" KcppVersion = "1.92.1"
showdebug = True showdebug = True
kcpp_instance = None #global running instance kcpp_instance = None #global running instance
global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False} global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False}

View file

@ -1287,6 +1287,10 @@ ggml_tensor * llm_graph_context::build_attn(
if (wo) { if (wo) {
cur = build_lora_mm(wo, cur); cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
} }
if (wo_b) { if (wo_b) {
@ -1367,10 +1371,6 @@ ggml_tensor * llm_graph_context::build_attn(
if (wo) { if (wo) {
cur = build_lora_mm(wo, cur); cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
} }
if (wo_b) { if (wo_b) {

View file

@ -1538,6 +1538,7 @@ struct llama_vocab::impl {
bool is_user_defined(llama_token id) const; bool is_user_defined(llama_token id) const;
bool is_unused (llama_token id) const; bool is_unused (llama_token id) const;
bool is_eog (llama_token id) const; bool is_eog (llama_token id) const;
std::set<int> get_eogs() const;
uint8_t token_to_byte(llama_token id) const; uint8_t token_to_byte(llama_token id) const;
@ -2396,6 +2397,10 @@ bool llama_vocab::impl::is_eog(llama_token id) const {
return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0; return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
} }
std::set<int> llama_vocab::impl::get_eogs() const {
return special_eog_ids;
}
uint8_t llama_vocab::impl::token_to_byte(llama_token id) const { uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
GGML_ASSERT(is_byte(id)); GGML_ASSERT(is_byte(id));
@ -3121,6 +3126,10 @@ bool llama_vocab::is_eog(llama_token id) const {
return pimpl->is_eog(id); return pimpl->is_eog(id);
} }
std::set<int> llama_vocab::get_eogs() const {
return pimpl->get_eogs();
}
uint8_t llama_vocab::token_to_byte(llama_token id) const { uint8_t llama_vocab::token_to_byte(llama_token id) const {
return pimpl->token_to_byte(id); return pimpl->token_to_byte(id);
} }
@ -3431,6 +3440,11 @@ llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
return vocab->token_eot(); return vocab->token_eot();
} }
std::set<int> llama_vocab_get_eogs(const struct llama_vocab * vocab)
{
return vocab->get_eogs();
}
// deprecated // deprecated
llama_token llama_vocab_cls(const struct llama_vocab * vocab) { llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
return vocab->token_bos(); return vocab->token_bos();

View file

@ -40,6 +40,7 @@ struct llama_vocab {
bool is_user_defined(llama_token id) const; bool is_user_defined(llama_token id) const;
bool is_unused (llama_token id) const; bool is_unused (llama_token id) const;
bool is_eog (llama_token id) const; bool is_eog (llama_token id) const;
std::set<int> get_eogs() const;
uint8_t token_to_byte(llama_token id) const; uint8_t token_to_byte(llama_token id) const;
llama_token byte_to_token(uint8_t ch) const; llama_token byte_to_token(uint8_t ch) const;