Quadratic Sampling UI (#652)

* Quadratic Sampling UI

Kalomaze's Quadratic Sampling, now has a UI within KCPP.

* remove debug prints

* cleanup, add smooth sampler to dynatemp

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
Alexander Abushady 2024-02-04 03:26:27 -05:00 committed by GitHub
parent 504300784f
commit 4cb956c7db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 57 additions and 38 deletions

View file

@ -83,9 +83,7 @@ static int n_batch = 8;
static bool useSmartContext = false;
static bool useContextShift = false;
static int blasbatchsize = 512;
static int dontblasbatchsize = 16;
static int normalbatchsize = 32;
static int smallbatchsize = 8;
static int smallbatchsize = 16;
static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
static std::string modelname;
static std::vector<gpt_vocab::id> last_n_tokens;
@ -427,18 +425,18 @@ void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, float presence_
}
void sample_temperature(llama_token_data_array * candidates_p, float temp)
void sample_temperature(llama_token_data_array * candidates_p, float temp, float smoothing_factor)
{
if (temp <= 0)
{
// Imitate greedy sampling
temp = 0.00390625f; //cannot be zero else div0, this is 1/256
llama_sample_temperature(nullptr, candidates_p, temp);
llama_sample_temperature(nullptr, candidates_p, temp, 0);
llama_sample_top_k(nullptr, candidates_p, 1, 1); //only want first candidate
}
else
{
llama_sample_temperature(nullptr, candidates_p, temp);
llama_sample_temperature(nullptr, candidates_p, temp, smoothing_factor);
}
}
@ -482,7 +480,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
}
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent)
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
{
int id = 0;
std::vector<llama_token_data> candidates;
@ -508,7 +506,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, presence_penalty, &candidates_p);
sample_temperature(&candidates_p, temp);
sample_temperature(&candidates_p, temp, smoothing_factor);
if (mirostat == 1)
{
id = sample_token_mirostat(n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
@ -549,11 +547,11 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
dynatemp_min = dynatemp_min<0?0:dynatemp_min;
dynatemp_max = dynatemp_max<0?0:dynatemp_max;
dynatemp_exponent = dynatemp_exponent<0?0:dynatemp_exponent;
llama_sample_entropy(nullptr, &candidates_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
llama_sample_entropy(nullptr, &candidates_p, dynatemp_min, dynatemp_max, dynatemp_exponent, smoothing_factor);
}
else
{
sample_temperature(&candidates_p, temp);
sample_temperature(&candidates_p, temp, smoothing_factor);
}
break;
case KCPP_SAMPLER_REP_PEN:
@ -698,7 +696,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
n_blasthreads = kcpp_params->n_threads_batch = inputs.blasthreads;
bool isGguf = (file_format == FileFormat::GGUF_GENERIC);
n_batch = kcpp_params->n_batch = (isGguf?normalbatchsize:smallbatchsize);
n_batch = kcpp_params->n_batch = smallbatchsize;
modelname = kcpp_params->model = inputs.model_filename;
useSmartContext = inputs.use_smartcontext;
useContextShift = inputs.use_contextshift;
@ -706,7 +704,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
blasbatchsize = inputs.blasbatchsize;
if(blasbatchsize<=0)
{
blasbatchsize = (isGguf?dontblasbatchsize:smallbatchsize);
blasbatchsize = smallbatchsize;
}
auto clamped_max_context_length = inputs.max_context_length;
@ -1533,6 +1531,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
kcpp_params->n_batch = n_batch;
kcpp_params->n_threads = n_threads;
kcpp_params->n_threads_batch = n_blasthreads;
kcpp_params->smoothing_factor = inputs.smoothing_factor;
bool stream_sse = inputs.stream_sse;
@ -1675,7 +1674,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
file_format == FileFormat::GPTJ_2 ||
file_format == FileFormat::RWKV_1 ||
file_format==FileFormat::RWKV_2);
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas() && blasbatchsize!=-1);
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas() && blasbatchsize>=32);
// bool blasmode = false;
int original_batch = kcpp_params->n_batch;
int original_threads = kcpp_params->n_threads;
@ -1930,6 +1929,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
const float tfs_z = kcpp_params->tfs_z;
const float dynatemp_range = kcpp_params->dynatemp_range;
const float dynatemp_exponent = kcpp_params->dynatemp_exponent;
const float smoothing_factor = kcpp_params->smoothing_factor;
if (!startedsampling)
{
@ -1985,7 +1985,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, presence_penalty,
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta, sampler_order, grammar, dynatemp_range, dynatemp_exponent);
kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta, sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);
if (grammar != nullptr) {
grammar_accept_token(file_format, n_vocab, grammar, id);