mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
integrated mirostat as a launch parameter, works on all models
This commit is contained in:
parent
851f55325a
commit
8a964e76c8
3 changed files with 102 additions and 26 deletions
|
@ -95,7 +95,62 @@ llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng
|
|||
return result;
|
||||
}
|
||||
|
||||
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng)
|
||||
llama_token sample_token_mirostat(int n_vocab, llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int m, float * mu)
|
||||
{
|
||||
float N = float(n_vocab);
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
// Estimate s_hat using the most probable m tokens
|
||||
float s_hat = 0.0;
|
||||
float sum_ti_bi = 0.0;
|
||||
float sum_ti_sq = 0.0;
|
||||
for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
|
||||
float t_i = logf(float(i + 2) / float(i + 1));
|
||||
float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
|
||||
sum_ti_bi += t_i * b_i;
|
||||
sum_ti_sq += t_i * t_i;
|
||||
}
|
||||
s_hat = sum_ti_bi / sum_ti_sq;
|
||||
// Compute k from the estimated s_hat and target surprise value
|
||||
float epsilon_hat = s_hat - 1;
|
||||
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
|
||||
// Sample the next word X using top-k sampling
|
||||
llama_sample_top_k(nullptr, candidates, int(k));
|
||||
llama_token X = sample_token(candidates, rng); // Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||
float e = observed_surprise - tau;
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token sample_token_mirostat_v2(llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float * mu)
|
||||
{
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
// Truncate the words with surprise values greater than mu
|
||||
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return -log2f(candidate.p) > *mu;
|
||||
}));
|
||||
// Normalize the probabilities of the remaining words
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
// Sample the next word X from the remaining words
|
||||
llama_token X = sample_token(candidates,rng);
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||
float e = observed_surprise - tau;
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
return X;
|
||||
}
|
||||
|
||||
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
|
||||
int mirostat, float mirostat_tau, float mirostat_eta)
|
||||
{
|
||||
int id = 0;
|
||||
std::vector<llama_token_data> candidates;
|
||||
|
@ -115,18 +170,37 @@ int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range
|
|||
// llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
|
||||
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
// last_n_repeat, alpha_frequency, alpha_presence);
|
||||
|
||||
if (temp <= 0) {
|
||||
|
||||
if (temp <= 0)
|
||||
{
|
||||
// Greedy sampling
|
||||
id = llama_sample_token_greedy(nullptr, &candidates_p);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(nullptr, &candidates_p, top_k);
|
||||
llama_sample_tail_free(nullptr, &candidates_p, tfs);
|
||||
llama_sample_typical(nullptr, &candidates_p, typical_p);
|
||||
llama_sample_top_p(nullptr, &candidates_p, top_p);
|
||||
llama_sample_temperature(nullptr, &candidates_p, temp);
|
||||
id = sample_token(&candidates_p, rng);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (mirostat == 1)
|
||||
{
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temperature(nullptr, &candidates_p, temp);
|
||||
id = sample_token_mirostat(n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||
}
|
||||
else if (mirostat == 2)
|
||||
{
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
llama_sample_temperature(nullptr, &candidates_p, temp);
|
||||
id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(nullptr, &candidates_p, top_k);
|
||||
llama_sample_tail_free(nullptr, &candidates_p, tfs);
|
||||
llama_sample_typical(nullptr, &candidates_p, typical_p);
|
||||
llama_sample_top_p(nullptr, &candidates_p, top_p);
|
||||
llama_sample_temperature(nullptr, &candidates_p, temp);
|
||||
id = sample_token(&candidates_p, rng);
|
||||
}
|
||||
}
|
||||
|
||||
return id;
|
||||
|
@ -647,7 +721,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
logits[29961] = 0;
|
||||
}
|
||||
|
||||
id = SampleLogits(logits, nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_p, typical_p, tfs_z, temp, rng);
|
||||
id = SampleLogits(logits, nctx, n_vocab, last_n_size, repeat_penalty,
|
||||
top_k, top_p, typical_p, tfs_z, temp, rng,
|
||||
params.mirostat,params.mirostat_tau,params.mirostat_eta);
|
||||
|
||||
}
|
||||
else
|
||||
|
@ -667,7 +743,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
//gpt2 uses negative logits, so we cant zero it
|
||||
}
|
||||
|
||||
id = SampleLogits(logits.data(), nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_p, typical_p, tfs_z, temp, rng);
|
||||
id = SampleLogits(logits.data(), nctx, n_vocab, last_n_size, repeat_penalty,
|
||||
top_k, top_p, typical_p, tfs_z, temp, rng,
|
||||
params.mirostat,params.mirostat_tau,params.mirostat_eta);
|
||||
}
|
||||
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue