diff --git a/expose.h b/expose.h
index 050613374..33dccd928 100644
--- a/expose.h
+++ b/expose.h
@@ -26,6 +26,8 @@ struct generation_inputs
     const float temperature;
     const int top_k;
     const float top_p;
+    const float typical_p;
+    const float tfs;
     const float rep_pen;
     const int rep_pen_range;
     const char * stop_sequence[stop_token_max];
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index b3730ce50..23aa60e79 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -78,6 +78,59 @@ inline bool LogitsDuplicated(std::vector<float> & arr1, std::vector<float> & arr2)
     return true;
 }
 
+
+llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
+{
+    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sample_softmax(nullptr, candidates);
+    std::vector<float> probs;
+    probs.reserve(candidates->size);
+    for (size_t i = 0; i < candidates->size; ++i) {
+        probs.push_back(candidates->data[i].p);
+    }
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+    llama_token result = candidates->data[idx].id;
+    return result;
+}
+
+int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng)
+{
+    int id = 0;
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+    // Apply penalties
+    auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
+    llama_sample_repetition_penalty(nullptr, &candidates_p,
+        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        last_n_repeat, rep_pen);
+
+    // llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
+    //     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+    //     last_n_repeat, alpha_frequency, alpha_presence);
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(nullptr, &candidates_p);
+    } else {
+        // Temperature sampling
+        llama_sample_top_k(nullptr, &candidates_p, top_k);
+        llama_sample_tail_free(nullptr, &candidates_p, tfs);
+        llama_sample_typical(nullptr, &candidates_p, typical_p);
+        llama_sample_top_p(nullptr, &candidates_p, top_p);
+        llama_sample_temperature(nullptr, &candidates_p, temp);
+        id = sample_token(&candidates_p, rng);
+    }
+
+    return id;
+}
+
 ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format)
 {
     ggml_time_init();
@@ -311,6 +364,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     params.n_predict = inputs.max_length;
     params.top_k = inputs.top_k;
     params.top_p = inputs.top_p;
+    params.typical_p = inputs.typical_p;
+    params.tfs_z = inputs.tfs;
     params.temp = inputs.temperature;
     params.repeat_last_n = inputs.rep_pen_range;
     params.repeat_penalty = inputs.rep_pen;
@@ -423,7 +478,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
     if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
     {
-        //do nothing
+        n_vocab = llama_n_vocab(llama_ctx_v1);
     }
     else if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2)
     {
@@ -557,6 +612,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         const float top_p = params.top_p;
         const float temp = params.temp;
         const float repeat_penalty = params.repeat_penalty;
+        const float typical_p = params.typical_p;
+        const float tfs_z = params.tfs_z;
 
         if (!startedsampling)
         {
@@ -581,7 +638,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 logits[29961] = 0;
             }
 
-            id = llama_sample_top_p_top_k(llama_ctx_v1, last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_penalty);
+            id = SampleLogits(logits, nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_p, typical_p, tfs_z, temp, rng);
 
         }
         else
@@ -601,7 +658,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 //gpt2 uses negative logits, so we cant zero it
             }
 
-            id = gptj_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
+            id = SampleLogits(logits.data(), nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_p, typical_p, tfs_z, temp, rng);
         }
 
         last_n_tokens.erase(last_n_tokens.begin());
diff --git a/koboldcpp.py b/koboldcpp.py
index 0567b8ce7..3a56fa8c9 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -32,6 +32,8 @@ class generation_inputs(ctypes.Structure):
                 ("temperature", ctypes.c_float),
                 ("top_k", ctypes.c_int),
                 ("top_p", ctypes.c_float),
+                ("typical_p", ctypes.c_float),
+                ("tfs", ctypes.c_float),
                 ("rep_pen", ctypes.c_float),
                 ("rep_pen_range", ctypes.c_int),
                 ("stop_sequence", ctypes.c_char_p * stop_token_max)]
@@ -146,7 +148,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
+def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85,typical_p=1.0,tfs=1.0,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
@@ -155,6 +157,8 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     inputs.temperature = temperature
     inputs.top_k = top_k
    inputs.top_p = top_p
+    inputs.typical_p = typical_p
+    inputs.tfs = tfs
     inputs.rep_pen = rep_pen
     inputs.rep_pen_range = rep_pen_range
     inputs.seed = seed
@@ -297,6 +301,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 temperature=genparams.get('temperature', 0.8),
                 top_k=genparams.get('top_k', 200),
                 top_p=genparams.get('top_p', 0.85),
+                typical_p=genparams.get('typical', 1.0),
+                tfs=genparams.get('tfs', 1.0),
                 rep_pen=genparams.get('rep_pen', 1.1),
                 rep_pen_range=genparams.get('rep_pen_range', 128),
                 seed=-1,
@@ -311,6 +317,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 temperature=genparams.get('temperature', 0.8),
                 top_k=genparams.get('top_k', 200),
                 top_p=genparams.get('top_p', 0.85),
+                typical_p=genparams.get('typical', 1.0),
+                tfs=genparams.get('tfs', 1.0),
                 rep_pen=genparams.get('rep_pen', 1.1),
                 rep_pen_range=genparams.get('rep_pen_range', 128),
                 seed=-1,
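
Reviewer note: with this patch, the two new knobs flow from the HTTP payload (JSON keys 'typical' and 'tfs') through the ctypes struct into SampleLogits, where samplers run in the order repetition penalty, top-k, tail-free (tfs), typical (typical_p), top-p, temperature; a temp of 0 or below falls back to greedy decoding. Below is a minimal usage sketch of the extended Python binding, not part of the diff: it assumes koboldcpp.py's generate() with a model already loaded via load_model(), and the prompt and sampler values are illustrative only. Both new parameters default to 1.0, which leaves the corresponding sampler effectively disabled.

# Usage sketch (not part of this patch): assumes a model was already
# loaded through load_model() inside koboldcpp.py's process.
text = generate(
    "Once upon a time",  # hypothetical prompt
    max_length=40,
    temperature=0.7,
    top_k=100,
    top_p=0.85,
    typical_p=0.95,      # typical sampling; the default 1.0 disables it
    tfs=0.97,            # tail-free sampling z; the default 1.0 disables it
    rep_pen=1.1,
    rep_pen_range=128,
)

Note that over the HTTP API the typical-sampling value is read from the 'typical' key rather than 'typical_p', while 'tfs' keeps the same name in both the JSON payload and the ctypes struct.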