fix to allow all EOGs to trigger a stop, occam's glm4 fix,

This commit is contained in:
Concedo 2025-05-24 22:55:11 +08:00
parent bd7a40f326
commit f97bbdde00
6 changed files with 54 additions and 22 deletions

View file

@ -301,14 +301,29 @@ static int GetEosID(FileFormat file_format, int32_t n_vocab)
}
return eosID;
}
static int GetEotID(FileFormat file_format)
static std::vector<int> GetEogIDs(FileFormat file_format, int32_t n_vocab)
{
std::vector<int> alleogs;
int eos = GetEosID(file_format, n_vocab);
if(file_format == FileFormat::GGUF_GENERIC)
{
const llama_vocab * tmpvocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
return llama_vocab_eot(tmpvocab);
int eot = llama_vocab_eot(tmpvocab);
std::set<int> eogs = tmpvocab->get_eogs();
if (eot >= 0) {
eogs.insert(eot);
}
if (eos >= 0) {
eogs.insert(eos);
}
alleogs = std::vector<int>(eogs.begin(), eogs.end());
} else {
if (eos >= 0) {
alleogs.push_back(eos);
}
}
return -1;
return alleogs;
}
static float LowestLogit(const std::vector<float> & logits)
@ -1550,8 +1565,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
}
}
const llama_token eos = GetEosID(file_format,n_vocab);
const llama_token eot = GetEotID(file_format);
const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;
@ -1559,7 +1573,8 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = FileFormatTokenizeID(id,file_format);
if (id == eos || (id==eot && id!=-1)) {
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
if (found_eog) {
if (!allow_eos) {
candidates->data[i].logit = -INFINITY;
}
@ -1711,7 +1726,9 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
{
if (token == GetEosID(file_format,n_vocab) || (token!=-1 && token == GetEotID(file_format))) {
const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), token) != eog_tokens.end();
if (found_eog) {
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
return;
@ -3827,8 +3844,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
unsigned int eosID = GetEosID(file_format, n_vocab);
unsigned int eotID = GetEotID(file_format);
const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
float * logitsPtr;
float lowestLogit = 0;
int btsize = banned_token_ids.size();
@ -3886,13 +3902,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
{
// set the logit of the eos token to very low to avoid sampling it
if(eosID!=LLAMA_TOKEN_NULL)
for(int i=0;i<eog_tokens.size();++i)
{
logitsPtr[eosID] = lowestLogit;
}
if(eotID!=-1)
{
logitsPtr[eotID] = lowestLogit;
logitsPtr[eog_tokens[i]] = lowestLogit;
}
}
if(btsize>0)
@ -3958,7 +3970,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
for (auto eid : embd)
{
std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, inputs.render_special);
if(!inputs.render_special && (eid==eosID || (eid==eotID && eid!=-1) || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), eid) != eog_tokens.end();
if(!inputs.render_special && (found_eog || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
{
tokenizedstr = ""; //prevent render
}
@ -4059,7 +4072,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if(!early_abort)
{
if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
if(!inputs.bypass_eos_token && inputs.allow_eos_token && found_eog)
{
if(allow_regular_prints)
{

View file

@ -12,6 +12,7 @@
#include <stdint.h>
#include <stdio.h>
#include <stdbool.h>
#include <set>
#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
@ -941,6 +942,8 @@ extern "C" {
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
LLAMA_API std::set<int> llama_vocab_get_eogs(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);

View file

@ -52,7 +52,7 @@ logit_bias_max = 512
dry_seq_break_max = 128
# global vars
KcppVersion = "1.92"
KcppVersion = "1.92.1"
showdebug = True
kcpp_instance = None #global running instance
global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False}

View file

@ -1287,6 +1287,10 @@ ggml_tensor * llm_graph_context::build_attn(
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
@ -1367,10 +1371,6 @@ ggml_tensor * llm_graph_context::build_attn(
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {

View file

@ -1538,6 +1538,7 @@ struct llama_vocab::impl {
bool is_user_defined(llama_token id) const;
bool is_unused (llama_token id) const;
bool is_eog (llama_token id) const;
std::set<int> get_eogs() const;
uint8_t token_to_byte(llama_token id) const;
@ -2396,6 +2397,10 @@ bool llama_vocab::impl::is_eog(llama_token id) const {
return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
}
std::set<int> llama_vocab::impl::get_eogs() const {
return special_eog_ids;
}
uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
GGML_ASSERT(is_byte(id));
@ -3121,6 +3126,10 @@ bool llama_vocab::is_eog(llama_token id) const {
return pimpl->is_eog(id);
}
std::set<int> llama_vocab::get_eogs() const {
return pimpl->get_eogs();
}
uint8_t llama_vocab::token_to_byte(llama_token id) const {
return pimpl->token_to_byte(id);
}
@ -3431,6 +3440,11 @@ llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
return vocab->token_eot();
}
std::set<int> llama_vocab_get_eogs(const struct llama_vocab * vocab)
{
return vocab->get_eogs();
}
// deprecated
llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
return vocab->token_bos();

View file

@ -40,6 +40,7 @@ struct llama_vocab {
bool is_user_defined(llama_token id) const;
bool is_unused (llama_token id) const;
bool is_eog (llama_token id) const;
std::set<int> get_eogs() const;
uint8_t token_to_byte(llama_token id) const;
llama_token byte_to_token(uint8_t ch) const;