Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 09:34:37 +00:00)
Alone in the darkness
They're coming for you
I know they will try to catch me too
Alone in the darkness
They're calling for you
There's nowhere to run for cover
Commit 94a5a27b85
44 changed files with 6803 additions and 2143 deletions
@@ -132,6 +132,7 @@ struct slot_params {
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters

int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
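For orientation, here is a minimal standalone sketch (not the repository's code; the helper name and signature are hypothetical) of how a millisecond limit like t_max_predict_ms can be checked against microsecond timestamps such as those returned by ggml_time_us():

    #include <cstdint>

    // Hypothetical helper: true once generation has run longer than t_max_predict_ms.
    // A non-positive limit means "no limit", matching the -1 default above.
    // Timestamps are microseconds, the limit is milliseconds.
    static bool time_limit_exceeded(int64_t t_start_us, int64_t t_now_us, int64_t t_max_predict_ms) {
        if (t_max_predict_ms <= 0) {
            return false;
        }
        return (t_now_us - t_start_us) > 1000 * t_max_predict_ms;
    }

The actual check appears further down in this diff, written inline against slot.t_start_generation.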
@@ -174,6 +175,8 @@ struct server_slot {
std::vector<llama_token> prompt_tokens;
std::vector<llama_token> extra_tokens;

size_t last_nl_pos = 0;

std::string generated_text;
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
@@ -216,6 +219,7 @@ struct server_slot {
SLT_DBG(*this, "%s", "\n");

n_prompt_tokens = 0;
last_nl_pos = 0;
generated_text = "";
has_new_line = false;
truncated = false;
@@ -861,6 +865,7 @@ struct server_context {
slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
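The json_value(data, key, default) calls above follow a read-with-fallback pattern. As a rough sketch of what such a helper can look like (an assumption for illustration, not necessarily the repository's exact implementation):

    #include <nlohmann/json.hpp>
    #include <string>

    // Sketch of a read-with-fallback helper: return body[key] when the field is
    // present and non-null, otherwise the supplied default value.
    template <typename T>
    static T json_value(const nlohmann::json & body, const std::string & key, const T & default_value) {
        if (body.contains(key) && !body.at(key).is_null()) {
            return body.value(key, default_value);
        }
        return default_value;
    }

With this shape, a request that omits "n_indent" simply falls back to default_params.n_indent.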
@@ -879,7 +884,7 @@ struct server_context {
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
@@ -1130,13 +1135,48 @@ struct server_context {
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}

// if we have already seen a new line, we stop after a certain time limit
if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
(ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stopped_limit = true;
slot.has_next_token = false;
if (slot.has_new_line) {
// if we have already seen a new line, we stop after a certain time limit
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stopped_limit = true;
slot.has_next_token = false;

SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
}

// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
if (slot.params.n_indent > 0) {
// check the current indentation
// TODO: improve by not doing it more than once for each new line
if (slot.last_nl_pos > 0) {
size_t pos = slot.last_nl_pos;

int n_indent = 0;
while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
n_indent++;
pos++;
}

if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
slot.stopped_limit = true;
slot.has_next_token = false;

// cut the last line
slot.generated_text.erase(pos, std::string::npos);

SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
}
}

// find the next new line
{
const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);

if (pos != std::string::npos) {
slot.last_nl_pos = pos + 1;
}
}
}
}

// check if there is a new line in the generated text
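To make the new indentation rule easier to follow in isolation, here is a minimal sketch (a hypothetical helper, restating the logic of the block above): starting just after the last newline, count leading spaces and tabs, and flag a stop only once a non-whitespace character has appeared with less indentation than required.

    #include <cstddef>
    #include <string>

    // Returns true when the current line has started with fewer than n_indent
    // whitespace characters, i.e. generation should stop and the line be cut.
    static bool line_violates_indent(const std::string & text, size_t last_nl_pos, int n_indent) {
        size_t pos = last_nl_pos;

        int indent = 0;
        while (pos < text.size() && (text[pos] == ' ' || text[pos] == '\t')) {
            indent++;
            pos++;
        }

        // only decide once a non-whitespace character exists on the line
        return pos < text.size() && indent < n_indent;
    }

The diff's version additionally erases the offending line from generated_text before stopping.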
@@ -2124,17 +2164,10 @@ struct server_context {
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}

common_sampler_reset(slot.smpl);

if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);

// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
}

// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
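The prompt-cache reuse above hinges on longest_common_prefix. A minimal sketch of what such a helper does (an assumption about its behavior, not necessarily the repository's exact code): count how many tokens match from the start of both sequences.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    typedef int32_t llama_token; // assumption: tokens are 32-bit ints

    // Number of leading tokens that are identical in both sequences.
    static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }

Every token inside this common prefix keeps its KV cache entry, so only the remaining suffix of the new prompt has to be evaluated.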
@@ -2167,8 +2200,6 @@ struct server_context {
for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];

common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);

slot.n_past++;
}

@@ -2220,8 +2251,6 @@ struct server_context {

// there is no common part left
slot.n_past = 0;

common_sampler_reset(slot.smpl);
}

SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
@@ -2249,6 +2278,13 @@ struct server_context {

GGML_ASSERT(batch.n_tokens > 0);

common_sampler_reset(slot.smpl);

// Process all prompt tokens through sampler system
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
common_sampler_accept(slot.smpl, prompt_tokens[i], false);
}

// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;

@@ -2287,7 +2323,6 @@ struct server_context {
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);