Mirror of https://github.com/LostRuins/koboldcpp.git
fix to allow all EOGs to trigger a stop, occam's glm4 fix
Commit f97bbdde00 (parent bd7a40f326)
6 changed files with 54 additions and 22 deletions
@@ -301,14 +301,29 @@ static int GetEosID(FileFormat file_format, int32_t n_vocab)
     }
     return eosID;
 }

-static int GetEotID(FileFormat file_format)
+static std::vector<int> GetEogIDs(FileFormat file_format, int32_t n_vocab)
 {
+    std::vector<int> alleogs;
+    int eos = GetEosID(file_format, n_vocab);
     if(file_format == FileFormat::GGUF_GENERIC)
     {
         const llama_vocab * tmpvocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
-        return llama_vocab_eot(tmpvocab);
+        int eot = llama_vocab_eot(tmpvocab);
+        std::set<int> eogs = tmpvocab->get_eogs();
+        if (eot >= 0) {
+            eogs.insert(eot);
+        }
+        if (eos >= 0) {
+            eogs.insert(eos);
+        }
+        alleogs = std::vector<int>(eogs.begin(), eogs.end());
+    } else {
+        if (eos >= 0) {
+            alleogs.push_back(eos);
+        }
     }
-    return -1;
+    return alleogs;
 }

 static float LowestLogit(const std::vector<float> & logits)
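The new GetEogIDs in the hunk above (apparently koboldcpp's gpttype_adapter.cpp) replaces the old single-token GetEotID: for GGUF models it merges the vocab's whole end-of-generation set with the EOS and EOT ids, deduplicating through a std::set, while legacy formats fall back to a list containing only EOS. A minimal usage sketch, assuming a float logits array as in the call sites below:

    // Sketch only: ban every end-of-generation token before sampling.
    // GetEogIDs comes from this commit; logits is an assumed float array.
    const std::vector<int> eog_tokens = GetEogIDs(file_format, n_vocab);
    for (const int tok : eog_tokens) {
        logits[tok] = -INFINITY; // no EOG token can be sampled now
    }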
@@ -1550,8 +1565,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
         }
     }

-    const llama_token eos = GetEosID(file_format,n_vocab);
-    const llama_token eot = GetEotID(file_format);
+    const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);

     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     std::vector<llama_grammar_candidate> candidates_grammar;

@@ -1559,7 +1573,8 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
         const std::string piece = FileFormatTokenizeID(id,file_format);
-        if (id == eos || (id==eot && id!=-1)) {
+        bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
+        if (found_eog) {
             if (!allow_eos) {
                 candidates->data[i].logit = -INFINITY;
             }

@@ -1711,7 +1726,9 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna

 static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
 {
-    if (token == GetEosID(file_format,n_vocab) || (token!=-1 && token == GetEotID(file_format))) {
+    const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
+    bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), token) != eog_tokens.end();
+    if (found_eog) {
         for (const auto & stack : grammar->stacks) {
             if (stack.empty()) {
                 return;
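Every call site above now tests membership with a linear std::find over eog_tokens. Since the list holds at most a handful of special token ids, that is cheap; if the set ever grew, a hashed container would be the natural swap. A self-contained sketch of that alternative (the token ids are made up):

    #include <unordered_set>
    #include <vector>

    // Alternative sketch: build the EOG set once, then query in O(1)
    // instead of running std::find per candidate token.
    int main() {
        const std::vector<int> eog_tokens = {2, 32000, 32007}; // made-up ids
        const std::unordered_set<int> eog_set(eog_tokens.begin(), eog_tokens.end());
        const int token = 32000;
        const bool found_eog = eog_set.count(token) > 0; // constant-time lookup
        return found_eog ? 0 : 1;
    }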
@@ -3827,8 +3844,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
     }

-    unsigned int eosID = GetEosID(file_format, n_vocab);
-    unsigned int eotID = GetEotID(file_format);
+    const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
     float * logitsPtr;
     float lowestLogit = 0;
     int btsize = banned_token_ids.size();

@@ -3886,13 +3902,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
         {
             // set the logit of the eos token to very low to avoid sampling it
-            if(eosID!=LLAMA_TOKEN_NULL)
+            for(int i=0;i<eog_tokens.size();++i)
             {
-                logitsPtr[eosID] = lowestLogit;
-            }
-            if(eotID!=-1)
-            {
-                logitsPtr[eotID] = lowestLogit;
+                logitsPtr[eog_tokens[i]] = lowestLogit;
             }
         }
         if(btsize>0)
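One nit in the replacement loop: the index i is declared int but compared against eog_tokens.size(), which is unsigned, so most compilers will emit a sign-compare warning. A range-for sketch with the same behavior, using logitsPtr and lowestLogit from the surrounding code:

    // Equivalent sketch without the signed/unsigned comparison:
    for (const int tok : eog_tokens) {
        logitsPtr[tok] = lowestLogit; // push every EOG token below sampling range
    }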
@@ -3958,7 +3970,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         for (auto eid : embd)
         {
             std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, inputs.render_special);
-            if(!inputs.render_special && (eid==eosID || (eid==eotID && eid!=-1) || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
+            bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), eid) != eog_tokens.end();
+            if(!inputs.render_special && (found_eog || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
             {
                 tokenizedstr = ""; //prevent render
             }

@@ -4059,7 +4072,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)

         if(!early_abort)
         {
-            if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
+            bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
+            if(!inputs.bypass_eos_token && inputs.allow_eos_token && found_eog)
             {
                 if(allow_regular_prints)
                 {
@@ -12,6 +12,7 @@
 #include <stdint.h>
 #include <stdio.h>
 #include <stdbool.h>
+#include <set>

 #ifdef LLAMA_SHARED
 #    if defined(_WIN32) && !defined(__MINGW32__)

@@ -941,6 +942,8 @@ extern "C" {
     LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
     LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding

+    LLAMA_API std::set<int> llama_vocab_get_eogs(const struct llama_vocab * vocab);
+
     LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
     LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
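One caveat with these header hunks (apparently this fork's llama.h): llama_vocab_get_eogs returns a std::set<int> from inside the extern "C" block, and the header now pulls in <set> unconditionally, so the header can no longer be included from plain C. That is presumably acceptable for koboldcpp's internal use. A hypothetical C-compatible shape, purely as a sketch and not part of the commit, would fill a caller-owned buffer instead:

    // Hypothetical sketch, not in the commit: C-friendly variant that writes
    // EOG ids into a caller-owned buffer and returns the total count.
    static int32_t llama_vocab_get_eogs_buf(const struct llama_vocab * vocab,
                                            int32_t * buf, int32_t buf_size) {
        const std::set<int> eogs = vocab->get_eogs(); // method added by this commit
        int32_t n = 0;
        for (const int id : eogs) {
            if (n < buf_size) { buf[n] = id; }
            ++n; // keep counting so the caller can detect a short buffer
        }
        return n;
    }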
@@ -52,7 +52,7 @@ logit_bias_max = 512
 dry_seq_break_max = 128

 # global vars
-KcppVersion = "1.92"
+KcppVersion = "1.92.1"
 showdebug = True
 kcpp_instance = None #global running instance
 global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False}
@@ -1287,6 +1287,10 @@ ggml_tensor * llm_graph_context::build_attn(
     if (wo) {
         cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
     }

     if (wo_b) {

@@ -1367,10 +1371,6 @@ ggml_tensor * llm_graph_context::build_attn(
     if (wo) {
         cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
     }

     if (wo_b) {
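Occam's GLM4 fix relocates the precision override between the two build_attn overloads (these hunks look like llama-graph.cpp), apparently so it sits on the code path GLM4 graphs actually take: the attention output projection is forced to accumulate in fp32, since half-precision accumulators reportedly produce broken output for this architecture. The pattern in isolation, with ctx, w, and x as assumed placeholders:

    // Sketch of the override pattern: request fp32 accumulation for one
    // matmul node instead of the backend's default half precision.
    ggml_tensor * out = ggml_mul_mat(ctx, w, x);
    ggml_mul_mat_set_prec(out, GGML_PREC_F32);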
@@ -1538,6 +1538,7 @@ struct llama_vocab::impl {
     bool is_user_defined(llama_token id) const;
     bool is_unused      (llama_token id) const;
     bool is_eog         (llama_token id) const;
+    std::set<int> get_eogs() const;

     uint8_t token_to_byte(llama_token id) const;

@@ -2396,6 +2397,10 @@ bool llama_vocab::impl::is_eog(llama_token id) const {
     return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
 }

+std::set<int> llama_vocab::impl::get_eogs() const {
+    return special_eog_ids;
+}
+
 uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
     GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
     GGML_ASSERT(is_byte(id));

@@ -3121,6 +3126,10 @@ bool llama_vocab::is_eog(llama_token id) const {
     return pimpl->is_eog(id);
 }

+std::set<int> llama_vocab::get_eogs() const {
+    return pimpl->get_eogs();
+}
+
 uint8_t llama_vocab::token_to_byte(llama_token id) const {
     return pimpl->token_to_byte(id);
 }

@@ -3431,6 +3440,11 @@ llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
     return vocab->token_eot();
 }

+std::set<int> llama_vocab_get_eogs(const struct llama_vocab * vocab)
+{
+    return vocab->get_eogs();
+}
+
 // deprecated
 llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
     return vocab->token_bos();
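The vocab hunks above plumb the EOG set through the pimpl: llama_vocab::impl::get_eogs exposes special_eog_ids, llama_vocab::get_eogs forwards to it, and llama_vocab_get_eogs wraps it for callers such as GetEogIDs earlier in this commit. A minimal caller sketch, assuming an already-loaded llama_model pointer named model:

    // Sketch: list every end-of-generation token id for a loaded model.
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const std::set<int> eogs = llama_vocab_get_eogs(vocab);
    for (const int id : eogs) {
        printf("EOG token id: %d\n", id); // e.g. EOS, EOT, other special stops
    }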
@@ -40,6 +40,7 @@ struct llama_vocab {
     bool is_user_defined(llama_token id) const;
     bool is_unused      (llama_token id) const;
     bool is_eog         (llama_token id) const;
+    std::set<int> get_eogs() const;

     uint8_t token_to_byte(llama_token id) const;
     llama_token byte_to_token(uint8_t ch) const;