mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
wip on rewind function
This commit is contained in:
parent
7dac9982f9
commit
3e8bb10e2d
2 changed files with 60 additions and 2 deletions
|
@ -314,7 +314,7 @@ static std::string get_tok_vec_str(std::vector<int> &embd)
|
||||||
}
|
}
|
||||||
static void print_tok_vec_str(std::vector<int> &vec)
|
static void print_tok_vec_str(std::vector<int> &vec)
|
||||||
{
|
{
|
||||||
printf("\n%s", get_tok_vec_str(vec).c_str());
|
printf("\n[%s]\n", get_tok_vec_str(vec).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool allExtendedUnicode(const std::string& str) {
|
bool allExtendedUnicode(const std::string& str) {
|
||||||
|
@ -401,6 +401,64 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ContextRewind(std::vector<int> &embd, std::vector<int> ¤t_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
|
||||||
|
{
|
||||||
|
if(amount_rewind<=0 || current_context_tokens.size()==0)
|
||||||
|
{
|
||||||
|
return; //do nothing
|
||||||
|
}
|
||||||
|
if(embd.size()>1)
|
||||||
|
{
|
||||||
|
printf("\nWARNING: Don't use context rewind when in batch processing phase!\n");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
bool is_mamba = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA);
|
||||||
|
bool is_rwkv_new = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_RWKV);
|
||||||
|
if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_mamba || is_rwkv_new)
|
||||||
|
{
|
||||||
|
printf("\nWARNING: RNN models do not support context rewind!\n");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (amount_rewind >= last_n_tokens.size())
|
||||||
|
{
|
||||||
|
last_n_tokens.clear();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
last_n_tokens.resize(last_n_tokens.size() - amount_rewind);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (amount_rewind >= current_context_tokens.size())
|
||||||
|
{
|
||||||
|
current_context_tokens.clear();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
current_context_tokens.resize(current_context_tokens.size() - amount_rewind);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (amount_rewind >= n_past)
|
||||||
|
{
|
||||||
|
n_past = 0;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
n_past -= amount_rewind;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (file_format == FileFormat::GGUF_GENERIC)
|
||||||
|
{
|
||||||
|
llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
embd.clear();
|
||||||
|
if(current_context_tokens.size()>0)
|
||||||
|
{
|
||||||
|
embd.push_back(current_context_tokens[current_context_tokens.size()-1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// KCPP SAMPLING FUNCTIONS
|
// KCPP SAMPLING FUNCTIONS
|
||||||
void sample_softmax(llama_token_data_array * cur_p) {
|
void sample_softmax(llama_token_data_array * cur_p) {
|
||||||
GGML_ASSERT(cur_p->size > 0);
|
GGML_ASSERT(cur_p->size > 0);
|
||||||
|
|
|
@ -3285,10 +3285,10 @@ def show_gui():
|
||||||
def load_config_gui(): #this is used to populate the GUI with a config file, whereas load_config_cli simply overwrites cli args
|
def load_config_gui(): #this is used to populate the GUI with a config file, whereas load_config_cli simply overwrites cli args
|
||||||
file_type = [("KoboldCpp Settings", "*.kcpps *.kcppt")]
|
file_type = [("KoboldCpp Settings", "*.kcpps *.kcppt")]
|
||||||
global runmode_untouched
|
global runmode_untouched
|
||||||
runmode_untouched = False
|
|
||||||
filename = askopenfilename(filetypes=file_type, defaultextension=file_type, initialdir=None)
|
filename = askopenfilename(filetypes=file_type, defaultextension=file_type, initialdir=None)
|
||||||
if not filename or filename=="":
|
if not filename or filename=="":
|
||||||
return
|
return
|
||||||
|
runmode_untouched = False
|
||||||
with open(filename, 'r') as f:
|
with open(filename, 'r') as f:
|
||||||
dict = json.load(f)
|
dict = json.load(f)
|
||||||
import_vars(dict)
|
import_vars(dict)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue