mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-19 08:00:25 +00:00
feat: add a primitive form of continuous batching (#2167)
* feat: add a primitive form of continuous batching * fix: deadlock in batching fallback * fix: windows build * chore: suppress the contbatch arg from --help * feat: batch-aware rep_pen_slope * fix: automatically disable shifting when batching is enabled * fix: mixed-path state corruption * fix: attempt to fully separate the two pipelines * added a semaphore to prevent non-batchable requests from starting while batched requests are running --------- Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
parent
a47037637c
commit
c03302b670
5 changed files with 924 additions and 16 deletions
|
|
@ -22,6 +22,11 @@
|
|||
#include <cctype>
|
||||
#include <locale>
|
||||
#include <chrono>
|
||||
#include <algorithm>
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
#include "utils.h"
|
||||
#include "llmutils.h"
|
||||
|
|
@ -76,6 +81,7 @@ int last_draft_success = 0;
|
|||
int last_draft_failed = 0;
|
||||
stop_reason last_stop_reason = stop_reason::INVALID;
|
||||
std::vector<std::string> generated_tokens;
|
||||
static int continuous_batching_slots = 0;
|
||||
|
||||
llama_grammar * grammar = nullptr; //currently used grammar
|
||||
llama_grammar_parser parsed_grammar;
|
||||
|
|
@ -2197,6 +2203,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
kcpp_pipeline_parallelism = inputs.pipelineparallel;
|
||||
kcpp_data->n_batch = GetBatchSize(inputs.batchsize, in_file_format);
|
||||
kcpp_data->n_ubatch = kcpp_data->n_batch;
|
||||
continuous_batching_slots = (isGguf && inputs.continuous_batching_slots > 1) ? inputs.continuous_batching_slots : 0;
|
||||
if(continuous_batching_slots > 0)
|
||||
{
|
||||
printf("Continuous batching: prepared %d GGUF sequence slots.\n", continuous_batching_slots);
|
||||
}
|
||||
kcpp_data->vision_min_tokens = inputs.visionmintokens;
|
||||
kcpp_data->vision_max_tokens = inputs.visionmaxtokens;
|
||||
if(isGguf && kcpp_pipeline_parallelism)
|
||||
|
|
@ -2461,6 +2472,10 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
|
||||
llama_ctx_params.n_batch = kcpp_data->n_batch;
|
||||
llama_ctx_params.n_ubatch = kcpp_data->n_ubatch;
|
||||
if(continuous_batching_slots > 0)
|
||||
{
|
||||
llama_ctx_params.n_seq_max = continuous_batching_slots + 1;
|
||||
}
|
||||
llama_ctx_params.n_threads = kcpp_data->n_threads;
|
||||
llama_ctx_params.n_threads_batch = kcpp_data->n_blasthreads;
|
||||
|
||||
|
|
@ -3269,6 +3284,745 @@ bool gpttype_generate_abort()
|
|||
return true;
|
||||
}
|
||||
|
||||
enum class BatchState
|
||||
{
|
||||
WAITING,
|
||||
PREFILL,
|
||||
GENERATING,
|
||||
FINISHED,
|
||||
FAILED,
|
||||
ABORTED,
|
||||
};
|
||||
|
||||
struct BatchGenerateRequest
|
||||
{
|
||||
int id = 0;
|
||||
int slot = -1;
|
||||
BatchState state = BatchState::WAITING;
|
||||
std::string prompt;
|
||||
std::vector<std::string> stop_sequences;
|
||||
int max_context_length = 0;
|
||||
int max_length = 0;
|
||||
int seed = 0;
|
||||
float temperature = 0.0f;
|
||||
int top_k = 0;
|
||||
float top_p = 1.0f;
|
||||
float min_p = 0.0f;
|
||||
float typical_p = 1.0f;
|
||||
float rep_pen = 1.0f;
|
||||
float rep_pen_slope = 1.0f;
|
||||
int rep_pen_range = 0;
|
||||
float presence_penalty = 0.0f;
|
||||
bool allow_eos_token = true;
|
||||
bool bypass_eos_token = false;
|
||||
bool render_special = false;
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
int prompt_pos = 0;
|
||||
int n_past = 0;
|
||||
bool has_pending = false;
|
||||
llama_token pending_token = 0;
|
||||
int i_batch = -1;
|
||||
llama_sampler * sampler = nullptr;
|
||||
std::vector<std::string> generated_pieces;
|
||||
std::string output;
|
||||
int prompt_token_count = 0;
|
||||
int completion_token_count = 0;
|
||||
std::chrono::steady_clock::time_point start_time;
|
||||
stop_reason finish_reason = stop_reason::INVALID;
|
||||
bool abort_requested = false;
|
||||
generation_outputs result;
|
||||
|
||||
~BatchGenerateRequest()
|
||||
{
|
||||
if(sampler)
|
||||
{
|
||||
llama_sampler_free(sampler);
|
||||
sampler = nullptr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static std::mutex batch_mutex;
|
||||
static std::condition_variable batch_cv;
|
||||
static std::deque<int> batch_waiting;
|
||||
static std::vector<std::unique_ptr<BatchGenerateRequest>> batch_requests;
|
||||
static std::thread batch_worker_thread;
|
||||
static bool batch_worker_stop = false;
|
||||
static bool batch_worker_started = false;
|
||||
static bool batch_legacy_active = false;
|
||||
static bool batch_touched_since_legacy = false;
|
||||
static int batch_legacy_waiting = 0;
|
||||
static int batch_next_request_id = 1;
|
||||
static std::string batch_empty_string = "";
|
||||
|
||||
static BatchGenerateRequest * batch_find_request_locked(int request_id)
|
||||
{
|
||||
for(auto & req : batch_requests)
|
||||
{
|
||||
if(req && req->id == request_id)
|
||||
{
|
||||
return req.get();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool batch_is_live_state(BatchState state)
|
||||
{
|
||||
return state == BatchState::WAITING || state == BatchState::PREFILL || state == BatchState::GENERATING;
|
||||
}
|
||||
|
||||
static bool batch_has_live_locked()
|
||||
{
|
||||
for(const auto & req : batch_requests)
|
||||
{
|
||||
if(req && batch_is_live_state(req->state))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void batch_invalidate_legacy_context_locked()
|
||||
{
|
||||
if(!batch_touched_since_legacy)
|
||||
{
|
||||
return;
|
||||
}
|
||||
batch_touched_since_legacy = false;
|
||||
n_past = 0;
|
||||
current_context_tokens.clear();
|
||||
last_n_tokens.clear();
|
||||
smartcontext.clear();
|
||||
loaded_latest_logits.clear();
|
||||
if(llama_ctx_v4)
|
||||
{
|
||||
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), 0, -1, -1);
|
||||
}
|
||||
if(draft_ctx)
|
||||
{
|
||||
llama_memory_seq_rm(llama_get_memory(draft_ctx), 0, -1, -1);
|
||||
}
|
||||
if(debugmode==1 && !is_quiet)
|
||||
{
|
||||
printf("\n[Continuous batching touched shared context; forcing next legacy generation to reprocess prompt]\n");
|
||||
}
|
||||
}
|
||||
|
||||
class BatchLegacyGuard
|
||||
{
|
||||
public:
|
||||
BatchLegacyGuard()
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(batch_mutex);
|
||||
batch_legacy_waiting++;
|
||||
batch_cv.notify_all();
|
||||
batch_cv.wait(lock, [](){ return !batch_has_live_locked(); });
|
||||
batch_legacy_waiting--;
|
||||
batch_invalidate_legacy_context_locked();
|
||||
batch_legacy_active = true;
|
||||
}
|
||||
|
||||
~BatchLegacyGuard()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
batch_legacy_active = false;
|
||||
batch_cv.notify_all();
|
||||
}
|
||||
};
|
||||
|
||||
static bool batch_inputs_eligible(const generation_inputs & inputs)
|
||||
{
|
||||
if(continuous_batching_slots <= 1 || file_format != FileFormat::GGUF_GENERIC || !llama_ctx_v4 || !kcpp_data)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(draft_ctx || guidance_ctx || clp_ctx_v || clp_ctx_a)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(kcpp_data->use_smartcontext || kcpp_data->use_contextshift || kcpp_data->smartcache)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.memory && std::string(inputs.memory).size() > 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.negative_prompt && std::string(inputs.negative_prompt).size() > 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f || !inputs.allow_eos_token)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.grammar && std::string(inputs.grammar).size() > 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.logit_biases_len > 0 || inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.mirostat != 0 || inputs.xtc_probability > 0.0f || inputs.nsigma > 0.0f || inputs.smoothing_factor > 0.0f || inputs.adaptive_target > 0.0f)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.top_a > 0.0f || inputs.tfs != 1.0f || inputs.dynatemp_range > 0.0f)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
static const int default_sampler_order[] = {6, 0, 1, 3, 4, 2, 5};
|
||||
if(inputs.sampler_len > 0)
|
||||
{
|
||||
if(inputs.sampler_len != 7)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(int i = 0; i < 7; ++i)
|
||||
{
|
||||
if((int) inputs.sampler_order[i] != default_sampler_order[i])
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(inputs.reasoning_budget >= 0 || inputs.tool_call_fix)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
struct BatchRepPenSampler
|
||||
{
|
||||
int32_t penalty_last_n = 0;
|
||||
float penalty_repeat = 1.0f;
|
||||
float penalty_slope = 1.0f;
|
||||
float penalty_present = 0.0f;
|
||||
std::vector<llama_token> prev;
|
||||
};
|
||||
|
||||
static const char * batch_rep_pen_name(const llama_sampler * /*smpl*/)
|
||||
{
|
||||
return "kcpp-batch-rep-pen";
|
||||
}
|
||||
|
||||
static void batch_rep_pen_accept(llama_sampler * smpl, llama_token token)
|
||||
{
|
||||
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
|
||||
if(ctx->penalty_last_n <= 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if(ctx->prev.size() >= (size_t) ctx->penalty_last_n)
|
||||
{
|
||||
ctx->prev.erase(ctx->prev.begin());
|
||||
}
|
||||
ctx->prev.push_back(token);
|
||||
}
|
||||
|
||||
static void batch_rep_pen_apply(llama_sampler * smpl, llama_token_data_array * cur_p)
|
||||
{
|
||||
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
|
||||
int last_n_repeat = std::min((int) ctx->prev.size(), ctx->penalty_last_n);
|
||||
if(last_n_repeat <= 0 || (ctx->penalty_repeat == 1.0f && ctx->penalty_present == 0.0f))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_token * last_tokens = ctx->prev.data() + ctx->prev.size() - last_n_repeat;
|
||||
std::unordered_set<llama_token> tokens_near(last_tokens + last_n_repeat / 2, last_tokens + last_n_repeat);
|
||||
std::unordered_set<llama_token> tokens_far(last_tokens, last_tokens + last_n_repeat / 2);
|
||||
|
||||
float penalty_reduced = ctx->penalty_repeat;
|
||||
if(penalty_reduced > 1.0f)
|
||||
{
|
||||
penalty_reduced = 1.0f + ((ctx->penalty_repeat - 1.0f) * ctx->penalty_slope);
|
||||
}
|
||||
|
||||
for(size_t i = 0; i < cur_p->size; ++i)
|
||||
{
|
||||
const bool token_in_near = tokens_near.find(cur_p->data[i].id) != tokens_near.end();
|
||||
const bool token_in_far = tokens_far.find(cur_p->data[i].id) != tokens_far.end();
|
||||
if(!token_in_near && !token_in_far)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
float penalty = token_in_near ? ctx->penalty_repeat : penalty_reduced;
|
||||
if(cur_p->data[i].logit <= 0)
|
||||
{
|
||||
cur_p->data[i].logit *= penalty;
|
||||
}
|
||||
else
|
||||
{
|
||||
cur_p->data[i].logit /= penalty;
|
||||
}
|
||||
cur_p->data[i].logit -= ctx->penalty_present;
|
||||
}
|
||||
|
||||
cur_p->sorted = false;
|
||||
}
|
||||
|
||||
static void batch_rep_pen_reset(llama_sampler * smpl)
|
||||
{
|
||||
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
|
||||
ctx->prev.clear();
|
||||
}
|
||||
|
||||
static llama_sampler * batch_rep_pen_clone(const llama_sampler * smpl)
|
||||
{
|
||||
const auto * ctx = (const BatchRepPenSampler *) smpl->ctx;
|
||||
auto * result = llama_sampler_init(smpl->iface, new BatchRepPenSampler {
|
||||
ctx->penalty_last_n,
|
||||
ctx->penalty_repeat,
|
||||
ctx->penalty_slope,
|
||||
ctx->penalty_present,
|
||||
ctx->prev,
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
static void batch_rep_pen_free(llama_sampler * smpl)
|
||||
{
|
||||
delete (BatchRepPenSampler *) smpl->ctx;
|
||||
}
|
||||
|
||||
static llama_sampler_i batch_rep_pen_i = {
|
||||
/* .name = */ batch_rep_pen_name,
|
||||
/* .accept = */ batch_rep_pen_accept,
|
||||
/* .apply = */ batch_rep_pen_apply,
|
||||
/* .reset = */ batch_rep_pen_reset,
|
||||
/* .clone = */ batch_rep_pen_clone,
|
||||
/* .free = */ batch_rep_pen_free,
|
||||
/* .backend_init = */ nullptr,
|
||||
/* .backend_accept = */ nullptr,
|
||||
/* .backend_apply = */ nullptr,
|
||||
/* .backend_set_input = */ nullptr,
|
||||
};
|
||||
|
||||
static llama_sampler * batch_rep_pen_init(int32_t penalty_last_n, float penalty_repeat, float penalty_slope, float penalty_present)
|
||||
{
|
||||
penalty_last_n = std::max(penalty_last_n, 0);
|
||||
if(penalty_slope <= 0.0f || penalty_slope > 1.0f)
|
||||
{
|
||||
penalty_slope = 1.0f;
|
||||
}
|
||||
return llama_sampler_init(&batch_rep_pen_i, new BatchRepPenSampler {
|
||||
penalty_last_n,
|
||||
penalty_repeat <= 0.0f ? 1.0f : penalty_repeat,
|
||||
penalty_slope,
|
||||
penalty_present,
|
||||
{},
|
||||
});
|
||||
}
|
||||
|
||||
static llama_sampler * batch_build_sampler(const BatchGenerateRequest & req)
|
||||
{
|
||||
llama_sampler_chain_params params = llama_sampler_chain_default_params();
|
||||
llama_sampler * chain = llama_sampler_chain_init(params);
|
||||
llama_sampler_chain_add(chain, batch_rep_pen_init(
|
||||
req.rep_pen_range,
|
||||
req.rep_pen,
|
||||
req.rep_pen_slope,
|
||||
req.presence_penalty));
|
||||
if(req.top_k > 0)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_k(req.top_k));
|
||||
}
|
||||
if(req.top_p > 0.0f && req.top_p < 1.0f)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_p(req.top_p, 1));
|
||||
}
|
||||
if(req.min_p > 0.0f)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_min_p(req.min_p, 1));
|
||||
}
|
||||
if(req.typical_p > 0.0f && req.typical_p < 1.0f)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_typical(req.typical_p, 1));
|
||||
}
|
||||
if(req.temperature > 0.0f)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_temp(req.temperature));
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_dist(req.seed < 0 ? LLAMA_DEFAULT_SEED : (uint32_t) req.seed));
|
||||
}
|
||||
else
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_greedy());
|
||||
}
|
||||
return chain;
|
||||
}
|
||||
|
||||
static void batch_finish_request_locked(BatchGenerateRequest & req, stop_reason reason)
|
||||
{
|
||||
auto finish_time = std::chrono::steady_clock::now();
|
||||
float total_time = req.start_time.time_since_epoch().count() == 0 ? 0.0f : std::chrono::duration<float>(finish_time - req.start_time).count();
|
||||
float generated_tps = total_time > 0.0f ? (float) req.completion_token_count / total_time : 0.0f;
|
||||
req.finish_reason = reason;
|
||||
req.result.status = (reason == stop_reason::ERROR_ENCOUNTERED) ? 0 : 1;
|
||||
req.result.stopreason = reason;
|
||||
req.result.prompt_tokens = req.prompt_token_count;
|
||||
req.result.completion_tokens = req.completion_token_count;
|
||||
req.result.text = req.output.c_str();
|
||||
req.state = reason == stop_reason::ERROR_ENCOUNTERED ? BatchState::FAILED : (reason == stop_reason::INVALID ? BatchState::ABORTED : BatchState::FINISHED);
|
||||
if(req.slot >= 0 && llama_ctx_v4)
|
||||
{
|
||||
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), req.slot, -1, -1);
|
||||
}
|
||||
req.slot = -1;
|
||||
printf("\n[%s] BatchRequest:%d, Prompt:%d, Generated:%d/%d in %.2fs (%.2fT/s), Stop:%d",
|
||||
get_timestamp_str().c_str(), req.id, req.prompt_token_count, req.completion_token_count, req.max_length, total_time, generated_tps, (int) reason);
|
||||
fflush(stdout);
|
||||
batch_cv.notify_all();
|
||||
}
|
||||
|
||||
static bool batch_output_hit_stop(const BatchGenerateRequest & req)
|
||||
{
|
||||
for(const auto & stopper : req.stop_sequences)
|
||||
{
|
||||
if(!stopper.empty() && req.output.find(stopper) != std::string::npos)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool batch_claim_waiting_locked()
|
||||
{
|
||||
bool claimed = false;
|
||||
for(int slot = 1; slot <= continuous_batching_slots && !batch_waiting.empty(); ++slot)
|
||||
{
|
||||
bool occupied = false;
|
||||
for(const auto & req : batch_requests)
|
||||
{
|
||||
if(req && req->slot == slot && batch_is_live_state(req->state))
|
||||
{
|
||||
occupied = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(occupied)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
int request_id = batch_waiting.front();
|
||||
batch_waiting.pop_front();
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req || req->state != BatchState::WAITING)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
req->slot = slot;
|
||||
req->state = BatchState::PREFILL;
|
||||
batch_touched_since_legacy = true;
|
||||
TokenizeString(req->prompt, req->prompt_tokens, file_format, add_bos_token);
|
||||
if(req->prompt_tokens.empty())
|
||||
{
|
||||
TokenizeString("", req->prompt_tokens, file_format, add_bos_token);
|
||||
}
|
||||
int n_ctx = req->max_context_length > 0 ? std::min(req->max_context_length, kcpp_data->n_ctx) : kcpp_data->n_ctx;
|
||||
if(req->max_length > 0 && (int) req->prompt_tokens.size() + req->max_length > n_ctx)
|
||||
{
|
||||
int keep = std::max(1, n_ctx - req->max_length);
|
||||
if((int) req->prompt_tokens.size() > keep)
|
||||
{
|
||||
req->prompt_tokens.erase(req->prompt_tokens.begin(), req->prompt_tokens.end() - keep);
|
||||
}
|
||||
}
|
||||
req->prompt_token_count = req->prompt_tokens.size();
|
||||
req->sampler = batch_build_sampler(*req);
|
||||
for(llama_token token : req->prompt_tokens)
|
||||
{
|
||||
llama_sampler_accept(req->sampler, token);
|
||||
}
|
||||
req->prompt_pos = 0;
|
||||
req->n_past = 0;
|
||||
req->has_pending = false;
|
||||
req->i_batch = -1;
|
||||
req->start_time = std::chrono::steady_clock::now();
|
||||
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), slot, -1, -1);
|
||||
claimed = true;
|
||||
}
|
||||
return claimed;
|
||||
}
|
||||
|
||||
static void batch_worker_loop()
|
||||
{
|
||||
const int batch_cap = std::max(1, kcpp_data ? kcpp_data->n_batch : 512);
|
||||
llama_batch batch = llama_batch_init(batch_cap, 0, 1);
|
||||
while(true)
|
||||
{
|
||||
std::vector<int> decode_ids;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(batch_mutex);
|
||||
batch_cv.wait_for(lock, std::chrono::milliseconds(5), [](){
|
||||
return batch_worker_stop || (!batch_legacy_active && batch_has_live_locked());
|
||||
});
|
||||
if(batch_worker_stop)
|
||||
{
|
||||
break;
|
||||
}
|
||||
if(batch_legacy_active)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
batch_claim_waiting_locked();
|
||||
common_batch_clear(batch);
|
||||
for(auto & req_ptr : batch_requests)
|
||||
{
|
||||
if(!req_ptr || !batch_is_live_state(req_ptr->state) || req_ptr->slot < 0 || batch.n_tokens >= batch_cap)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
BatchGenerateRequest & req = *req_ptr;
|
||||
req.i_batch = -1;
|
||||
if(req.abort_requested)
|
||||
{
|
||||
batch_finish_request_locked(req, stop_reason::INVALID);
|
||||
continue;
|
||||
}
|
||||
if(req.state == BatchState::PREFILL)
|
||||
{
|
||||
while(req.prompt_pos < (int) req.prompt_tokens.size() && batch.n_tokens < batch_cap)
|
||||
{
|
||||
bool is_last = req.prompt_pos == (int) req.prompt_tokens.size() - 1;
|
||||
if(is_last)
|
||||
{
|
||||
req.i_batch = batch.n_tokens;
|
||||
}
|
||||
common_batch_add(batch, req.prompt_tokens[req.prompt_pos], req.n_past, { req.slot }, is_last);
|
||||
req.prompt_pos++;
|
||||
req.n_past++;
|
||||
}
|
||||
if(req.prompt_pos == (int) req.prompt_tokens.size())
|
||||
{
|
||||
req.state = BatchState::GENERATING;
|
||||
}
|
||||
}
|
||||
else if(req.state == BatchState::GENERATING && req.has_pending)
|
||||
{
|
||||
req.i_batch = batch.n_tokens;
|
||||
common_batch_add(batch, req.pending_token, req.n_past, { req.slot }, true);
|
||||
req.n_past++;
|
||||
req.has_pending = false;
|
||||
}
|
||||
}
|
||||
if(batch.n_tokens == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(auto & req_ptr : batch_requests)
|
||||
{
|
||||
if(req_ptr && req_ptr->i_batch >= 0)
|
||||
{
|
||||
decode_ids.push_back(req_ptr->id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int decode_status = llama_decode(llama_ctx_v4, batch);
|
||||
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
if(decode_status != 0)
|
||||
{
|
||||
for(int request_id : decode_ids)
|
||||
{
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(req && batch_is_live_state(req->state))
|
||||
{
|
||||
batch_finish_request_locked(*req, stop_reason::ERROR_ENCOUNTERED);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
|
||||
llama_token eos = llama_vocab_eos(vocab);
|
||||
for(int request_id : decode_ids)
|
||||
{
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req || req->state != BatchState::GENERATING || req->i_batch < 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
llama_token sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
|
||||
if(!req->allow_eos_token && !req->bypass_eos_token && sampled == eos)
|
||||
{
|
||||
sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
|
||||
}
|
||||
req->completion_token_count++;
|
||||
if(sampled == eos && !req->bypass_eos_token)
|
||||
{
|
||||
batch_finish_request_locked(*req, stop_reason::EOS_TOKEN_HIT);
|
||||
continue;
|
||||
}
|
||||
std::string piece = FileFormatTokenizeID(sampled, file_format, req->render_special);
|
||||
req->generated_pieces.push_back(piece);
|
||||
req->output += piece;
|
||||
if(batch_output_hit_stop(*req))
|
||||
{
|
||||
batch_finish_request_locked(*req, stop_reason::CUSTOM_STOPPER);
|
||||
continue;
|
||||
}
|
||||
if(req->max_length > 0 && req->completion_token_count >= req->max_length)
|
||||
{
|
||||
batch_finish_request_locked(*req, stop_reason::OUT_OF_TOKENS);
|
||||
continue;
|
||||
}
|
||||
req->pending_token = sampled;
|
||||
req->has_pending = true;
|
||||
req->i_batch = -1;
|
||||
}
|
||||
}
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
static void batch_start_worker_locked()
|
||||
{
|
||||
if(batch_worker_started)
|
||||
{
|
||||
return;
|
||||
}
|
||||
batch_worker_stop = false;
|
||||
batch_worker_thread = std::thread(batch_worker_loop);
|
||||
batch_worker_thread.detach();
|
||||
batch_worker_started = true;
|
||||
}
|
||||
|
||||
bool gpttype_batch_generate_enabled()
|
||||
{
|
||||
return continuous_batching_slots > 1 && file_format == FileFormat::GGUF_GENERIC && llama_ctx_v4 && kcpp_data;
|
||||
}
|
||||
|
||||
int gpttype_batch_generate_submit(const generation_inputs inputs)
|
||||
{
|
||||
if(!batch_inputs_eligible(inputs))
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
if(batch_legacy_active || batch_legacy_waiting > 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
auto req = std::make_unique<BatchGenerateRequest>();
|
||||
req->id = batch_next_request_id++;
|
||||
req->prompt = inputs.prompt ? inputs.prompt : "";
|
||||
req->max_context_length = inputs.max_context_length;
|
||||
req->max_length = inputs.max_length;
|
||||
req->seed = inputs.seed;
|
||||
req->temperature = inputs.temperature;
|
||||
req->top_k = inputs.top_k;
|
||||
req->top_p = inputs.top_p;
|
||||
req->min_p = inputs.min_p;
|
||||
req->typical_p = inputs.typical_p;
|
||||
req->rep_pen = inputs.rep_pen;
|
||||
req->rep_pen_slope = inputs.rep_pen_slope;
|
||||
req->rep_pen_range = inputs.rep_pen_range;
|
||||
req->presence_penalty = inputs.presence_penalty;
|
||||
req->allow_eos_token = inputs.allow_eos_token;
|
||||
req->bypass_eos_token = inputs.bypass_eos_token;
|
||||
req->render_special = inputs.render_special;
|
||||
for(int i = 0; i < inputs.stop_sequence_len; ++i)
|
||||
{
|
||||
if(inputs.stop_sequence[i])
|
||||
{
|
||||
req->stop_sequences.emplace_back(inputs.stop_sequence[i]);
|
||||
}
|
||||
}
|
||||
int request_id = req->id;
|
||||
batch_requests.emplace_back(std::move(req));
|
||||
batch_waiting.push_back(request_id);
|
||||
batch_start_worker_locked();
|
||||
batch_cv.notify_all();
|
||||
return request_id;
|
||||
}
|
||||
|
||||
bool gpttype_batch_generate_has_finished(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
return !req || !batch_is_live_state(req->state);
|
||||
}
|
||||
|
||||
int gpttype_batch_generate_stream_count(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
return req ? req->generated_pieces.size() : 0;
|
||||
}
|
||||
|
||||
const char * gpttype_batch_generate_new_token(int request_id, int idx)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req || idx < 0 || idx >= (int) req->generated_pieces.size())
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
return req->generated_pieces[idx].c_str();
|
||||
}
|
||||
|
||||
const char * gpttype_batch_generate_pending_output(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req)
|
||||
{
|
||||
return batch_empty_string.c_str();
|
||||
}
|
||||
return req->output.c_str();
|
||||
}
|
||||
|
||||
generation_outputs gpttype_batch_generate_result(int request_id)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(batch_mutex);
|
||||
batch_cv.wait(lock, [request_id](){
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
return !req || !batch_is_live_state(req->state);
|
||||
});
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req)
|
||||
{
|
||||
generation_outputs output;
|
||||
output.status = 0;
|
||||
output.stopreason = stop_reason::ERROR_ENCOUNTERED;
|
||||
output.prompt_tokens = 0;
|
||||
output.completion_tokens = 0;
|
||||
output.text = batch_empty_string.c_str();
|
||||
return output;
|
||||
}
|
||||
req->result.text = req->output.c_str();
|
||||
return req->result;
|
||||
}
|
||||
|
||||
bool gpttype_batch_generate_abort(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
req->abort_requested = true;
|
||||
batch_cv.notify_all();
|
||||
return true;
|
||||
}
|
||||
|
||||
void gpttype_batch_generate_release(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
batch_requests.erase(std::remove_if(batch_requests.begin(), batch_requests.end(), [request_id](const std::unique_ptr<BatchGenerateRequest> & req){
|
||||
return req && req->id == request_id && !batch_is_live_state(req->state);
|
||||
}), batch_requests.end());
|
||||
batch_cv.notify_all();
|
||||
}
|
||||
|
||||
std::string gpttype_get_chat_template()
|
||||
{
|
||||
if(kcpp_data==nullptr)
|
||||
|
|
@ -3547,6 +4301,7 @@ int smartcache_quick_snapshot(int specific_slot = -1)
|
|||
|
||||
generation_outputs gpttype_generate(const generation_inputs inputs)
|
||||
{
|
||||
BatchLegacyGuard batch_legacy_guard;
|
||||
generation_outputs output;
|
||||
|
||||
if(kcpp_data==nullptr)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue