mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
wip on rewind function
This commit is contained in:
parent
7dac9982f9
commit
3e8bb10e2d
2 changed files with 60 additions and 2 deletions
|
@ -314,7 +314,7 @@ static std::string get_tok_vec_str(std::vector<int> &embd)
|
|||
}
|
||||
static void print_tok_vec_str(std::vector<int> &vec)
|
||||
{
|
||||
printf("\n%s", get_tok_vec_str(vec).c_str());
|
||||
printf("\n[%s]\n", get_tok_vec_str(vec).c_str());
|
||||
}
|
||||
|
||||
bool allExtendedUnicode(const std::string& str) {
|
||||
|
@ -401,6 +401,64 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
|
|||
}
|
||||
}
|
||||
|
||||
void ContextRewind(std::vector<int> &embd, std::vector<int> ¤t_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
|
||||
{
|
||||
if(amount_rewind<=0 || current_context_tokens.size()==0)
|
||||
{
|
||||
return; //do nothing
|
||||
}
|
||||
if(embd.size()>1)
|
||||
{
|
||||
printf("\nWARNING: Don't use context rewind when in batch processing phase!\n");
|
||||
return;
|
||||
}
|
||||
bool is_mamba = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA);
|
||||
bool is_rwkv_new = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_RWKV);
|
||||
if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_mamba || is_rwkv_new)
|
||||
{
|
||||
printf("\nWARNING: RNN models do not support context rewind!\n");
|
||||
return;
|
||||
}
|
||||
|
||||
if (amount_rewind >= last_n_tokens.size())
|
||||
{
|
||||
last_n_tokens.clear();
|
||||
}
|
||||
else
|
||||
{
|
||||
last_n_tokens.resize(last_n_tokens.size() - amount_rewind);
|
||||
}
|
||||
|
||||
if (amount_rewind >= current_context_tokens.size())
|
||||
{
|
||||
current_context_tokens.clear();
|
||||
}
|
||||
else
|
||||
{
|
||||
current_context_tokens.resize(current_context_tokens.size() - amount_rewind);
|
||||
}
|
||||
|
||||
if (amount_rewind >= n_past)
|
||||
{
|
||||
n_past = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
n_past -= amount_rewind;
|
||||
}
|
||||
|
||||
if (file_format == FileFormat::GGUF_GENERIC)
|
||||
{
|
||||
llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
|
||||
}
|
||||
|
||||
embd.clear();
|
||||
if(current_context_tokens.size()>0)
|
||||
{
|
||||
embd.push_back(current_context_tokens[current_context_tokens.size()-1]);
|
||||
}
|
||||
}
|
||||
|
||||
// KCPP SAMPLING FUNCTIONS
|
||||
void sample_softmax(llama_token_data_array * cur_p) {
|
||||
GGML_ASSERT(cur_p->size > 0);
|
||||
|
|
|
@ -3285,10 +3285,10 @@ def show_gui():
|
|||
def load_config_gui(): #this is used to populate the GUI with a config file, whereas load_config_cli simply overwrites cli args
|
||||
file_type = [("KoboldCpp Settings", "*.kcpps *.kcppt")]
|
||||
global runmode_untouched
|
||||
runmode_untouched = False
|
||||
filename = askopenfilename(filetypes=file_type, defaultextension=file_type, initialdir=None)
|
||||
if not filename or filename=="":
|
||||
return
|
||||
runmode_untouched = False
|
||||
with open(filename, 'r') as f:
|
||||
dict = json.load(f)
|
||||
import_vars(dict)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue