wip on rewind function

This commit is contained in:
Concedo 2024-10-06 16:21:03 +08:00
parent 7dac9982f9
commit 3e8bb10e2d
2 changed files with 60 additions and 2 deletions

View file

@ -314,7 +314,7 @@ static std::string get_tok_vec_str(std::vector<int> &embd)
} }
static void print_tok_vec_str(std::vector<int> &vec) static void print_tok_vec_str(std::vector<int> &vec)
{ {
printf("\n%s", get_tok_vec_str(vec).c_str()); printf("\n[%s]\n", get_tok_vec_str(vec).c_str());
} }
bool allExtendedUnicode(const std::string& str) { bool allExtendedUnicode(const std::string& str) {
@ -401,6 +401,64 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
} }
} }
// Rolls the in-progress generation state back by `amount_rewind` tokens.
//
// embd                   - pending token buffer; on success it is reset to hold
//                          only the last token left in current_context_tokens
//                          (or left empty if no context remains).
// current_context_tokens - token history; its last `amount_rewind` entries are dropped.
// n_past                 - decremented by `amount_rewind`, clamped at 0.
// last_n_tokens          - trimmed from the tail the same way as the context.
// amount_rewind          - number of tokens to undo; values <= 0 are a no-op.
//
// The call is refused (with a printed warning, state untouched) when more than
// one token is pending in `embd`, or when the loaded model is one of the
// RWKV / Mamba formats checked below.
void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
{
    // Nothing requested, or nothing to rewind.
    if (amount_rewind <= 0 || current_context_tokens.empty())
    {
        return;
    }

    // More than one pending token means a batch is still being processed.
    if (embd.size() > 1)
    {
        printf("\nWARNING: Don't use context rewind when in batch processing phase!\n");
        return;
    }

    const bool is_gguf = (file_format == FileFormat::GGUF_GENERIC);
    const bool is_legacy_rwkv = (file_format == FileFormat::RWKV_1 || file_format == FileFormat::RWKV_2);
    const bool is_gguf_rnn = is_gguf
        && (file_format_meta.model_architecture == GGUFArch::ARCH_MAMBA
            || file_format_meta.model_architecture == GGUFArch::ARCH_RWKV);
    if (is_legacy_rwkv || is_gguf_rnn)
    {
        printf("\nWARNING: RNN models do not support context rewind!\n");
        return;
    }

    // amount_rewind is known to be positive here, so the conversion is lossless.
    const size_t drop_count = static_cast<size_t>(amount_rewind);

    // Removes up to drop_count trailing entries, emptying the vector if it is shorter.
    auto drop_tail = [drop_count](std::vector<int> &tokens)
    {
        const size_t kept = (drop_count >= tokens.size()) ? 0 : tokens.size() - drop_count;
        tokens.resize(kept);
    };
    drop_tail(last_n_tokens);
    drop_tail(current_context_tokens);

    n_past = (amount_rewind >= n_past) ? 0 : (n_past - amount_rewind);

    if (is_gguf)
    {
        // Drop cached entries of sequence 0 from the new n_past onward
        // (-1 as the end position is taken to mean "to the end" - see llama.h).
        llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
    }

    // Re-seed the pending buffer with the last token that survived the rewind.
    embd.clear();
    if (!current_context_tokens.empty())
    {
        embd.push_back(current_context_tokens.back());
    }
}
// KCPP SAMPLING FUNCTIONS // KCPP SAMPLING FUNCTIONS
void sample_softmax(llama_token_data_array * cur_p) { void sample_softmax(llama_token_data_array * cur_p) {
GGML_ASSERT(cur_p->size > 0); GGML_ASSERT(cur_p->size > 0);

View file

@ -3285,10 +3285,10 @@ def show_gui():
def load_config_gui(): #this is used to populate the GUI with a config file, whereas load_config_cli simply overwrites cli args def load_config_gui(): #this is used to populate the GUI with a config file, whereas load_config_cli simply overwrites cli args
file_type = [("KoboldCpp Settings", "*.kcpps *.kcppt")] file_type = [("KoboldCpp Settings", "*.kcpps *.kcppt")]
global runmode_untouched global runmode_untouched
runmode_untouched = False
filename = askopenfilename(filetypes=file_type, defaultextension=file_type, initialdir=None) filename = askopenfilename(filetypes=file_type, defaultextension=file_type, initialdir=None)
if not filename or filename=="": if not filename or filename=="":
return return
runmode_untouched = False
with open(filename, 'r') as f: with open(filename, 'r') as f:
dict = json.load(f) dict = json.load(f)
import_vars(dict) import_vars(dict)