feat: add a primitive form of continuous batching (#2167)

* feat: add a primitive form of continuous batching

* fix: deadlock in batching fallback

* fix: windows build

* chore: suppress the contbatch arg from --help

* feat: batch-aware rep_pen_slope

* fix: automatically disable shifting when batching is enabled

* fix: mixed-path state corruption

* fix: attempt to fully separate the two pipelines

* added a semaphore to prevent non-batchable requests from starting while batched requests are running

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
AlpinDale 2026-05-10 14:20:31 +04:30 committed by GitHub
parent a47037637c
commit c03302b670
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 924 additions and 16 deletions

View file

@ -22,6 +22,11 @@
#include <cctype>
#include <locale>
#include <chrono>
#include <algorithm>
#include <condition_variable>
#include <deque>
#include <memory>
#include <thread>
#include "utils.h"
#include "llmutils.h"
@ -76,6 +81,7 @@ int last_draft_success = 0;
int last_draft_failed = 0;
stop_reason last_stop_reason = stop_reason::INVALID;
std::vector<std::string> generated_tokens;
static int continuous_batching_slots = 0;
llama_grammar * grammar = nullptr; //currently used grammar
llama_grammar_parser parsed_grammar;
@ -2197,6 +2203,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
kcpp_pipeline_parallelism = inputs.pipelineparallel;
kcpp_data->n_batch = GetBatchSize(inputs.batchsize, in_file_format);
kcpp_data->n_ubatch = kcpp_data->n_batch;
continuous_batching_slots = (isGguf && inputs.continuous_batching_slots > 1) ? inputs.continuous_batching_slots : 0;
if(continuous_batching_slots > 0)
{
printf("Continuous batching: prepared %d GGUF sequence slots.\n", continuous_batching_slots);
}
kcpp_data->vision_min_tokens = inputs.visionmintokens;
kcpp_data->vision_max_tokens = inputs.visionmaxtokens;
if(isGguf && kcpp_pipeline_parallelism)
@ -2461,6 +2472,10 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
llama_ctx_params.n_batch = kcpp_data->n_batch;
llama_ctx_params.n_ubatch = kcpp_data->n_ubatch;
if(continuous_batching_slots > 0)
{
llama_ctx_params.n_seq_max = continuous_batching_slots + 1;
}
llama_ctx_params.n_threads = kcpp_data->n_threads;
llama_ctx_params.n_threads_batch = kcpp_data->n_blasthreads;
@ -3269,6 +3284,745 @@ bool gpttype_generate_abort()
return true;
}
enum class BatchState
{
WAITING,
PREFILL,
GENERATING,
FINISHED,
FAILED,
ABORTED,
};
struct BatchGenerateRequest
{
int id = 0;
int slot = -1;
BatchState state = BatchState::WAITING;
std::string prompt;
std::vector<std::string> stop_sequences;
int max_context_length = 0;
int max_length = 0;
int seed = 0;
float temperature = 0.0f;
int top_k = 0;
float top_p = 1.0f;
float min_p = 0.0f;
float typical_p = 1.0f;
float rep_pen = 1.0f;
float rep_pen_slope = 1.0f;
int rep_pen_range = 0;
float presence_penalty = 0.0f;
bool allow_eos_token = true;
bool bypass_eos_token = false;
bool render_special = false;
std::vector<llama_token> prompt_tokens;
int prompt_pos = 0;
int n_past = 0;
bool has_pending = false;
llama_token pending_token = 0;
int i_batch = -1;
llama_sampler * sampler = nullptr;
std::vector<std::string> generated_pieces;
std::string output;
int prompt_token_count = 0;
int completion_token_count = 0;
std::chrono::steady_clock::time_point start_time;
stop_reason finish_reason = stop_reason::INVALID;
bool abort_requested = false;
generation_outputs result;
~BatchGenerateRequest()
{
if(sampler)
{
llama_sampler_free(sampler);
sampler = nullptr;
}
}
};
static std::mutex batch_mutex;
static std::condition_variable batch_cv;
static std::deque<int> batch_waiting;
static std::vector<std::unique_ptr<BatchGenerateRequest>> batch_requests;
static std::thread batch_worker_thread;
static bool batch_worker_stop = false;
static bool batch_worker_started = false;
static bool batch_legacy_active = false;
static bool batch_touched_since_legacy = false;
static int batch_legacy_waiting = 0;
static int batch_next_request_id = 1;
static std::string batch_empty_string = "";
static BatchGenerateRequest * batch_find_request_locked(int request_id)
{
for(auto & req : batch_requests)
{
if(req && req->id == request_id)
{
return req.get();
}
}
return nullptr;
}
static bool batch_is_live_state(BatchState state)
{
return state == BatchState::WAITING || state == BatchState::PREFILL || state == BatchState::GENERATING;
}
static bool batch_has_live_locked()
{
for(const auto & req : batch_requests)
{
if(req && batch_is_live_state(req->state))
{
return true;
}
}
return false;
}
static void batch_invalidate_legacy_context_locked()
{
if(!batch_touched_since_legacy)
{
return;
}
batch_touched_since_legacy = false;
n_past = 0;
current_context_tokens.clear();
last_n_tokens.clear();
smartcontext.clear();
loaded_latest_logits.clear();
if(llama_ctx_v4)
{
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), 0, -1, -1);
}
if(draft_ctx)
{
llama_memory_seq_rm(llama_get_memory(draft_ctx), 0, -1, -1);
}
if(debugmode==1 && !is_quiet)
{
printf("\n[Continuous batching touched shared context; forcing next legacy generation to reprocess prompt]\n");
}
}
class BatchLegacyGuard
{
public:
BatchLegacyGuard()
{
std::unique_lock<std::mutex> lock(batch_mutex);
batch_legacy_waiting++;
batch_cv.notify_all();
batch_cv.wait(lock, [](){ return !batch_has_live_locked(); });
batch_legacy_waiting--;
batch_invalidate_legacy_context_locked();
batch_legacy_active = true;
}
~BatchLegacyGuard()
{
std::lock_guard<std::mutex> lock(batch_mutex);
batch_legacy_active = false;
batch_cv.notify_all();
}
};
static bool batch_inputs_eligible(const generation_inputs & inputs)
{
if(continuous_batching_slots <= 1 || file_format != FileFormat::GGUF_GENERIC || !llama_ctx_v4 || !kcpp_data)
{
return false;
}
if(draft_ctx || guidance_ctx || clp_ctx_v || clp_ctx_a)
{
return false;
}
if(kcpp_data->use_smartcontext || kcpp_data->use_contextshift || kcpp_data->smartcache)
{
return false;
}
if(inputs.memory && std::string(inputs.memory).size() > 0)
{
return false;
}
if(inputs.negative_prompt && std::string(inputs.negative_prompt).size() > 0)
{
return false;
}
if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f || !inputs.allow_eos_token)
{
return false;
}
if(inputs.grammar && std::string(inputs.grammar).size() > 0)
{
return false;
}
if(inputs.logit_biases_len > 0 || inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
{
return false;
}
if(inputs.mirostat != 0 || inputs.xtc_probability > 0.0f || inputs.nsigma > 0.0f || inputs.smoothing_factor > 0.0f || inputs.adaptive_target > 0.0f)
{
return false;
}
if(inputs.top_a > 0.0f || inputs.tfs != 1.0f || inputs.dynatemp_range > 0.0f)
{
return false;
}
static const int default_sampler_order[] = {6, 0, 1, 3, 4, 2, 5};
if(inputs.sampler_len > 0)
{
if(inputs.sampler_len != 7)
{
return false;
}
for(int i = 0; i < 7; ++i)
{
if((int) inputs.sampler_order[i] != default_sampler_order[i])
{
return false;
}
}
}
if(inputs.reasoning_budget >= 0 || inputs.tool_call_fix)
{
return false;
}
return true;
}
struct BatchRepPenSampler
{
int32_t penalty_last_n = 0;
float penalty_repeat = 1.0f;
float penalty_slope = 1.0f;
float penalty_present = 0.0f;
std::vector<llama_token> prev;
};
static const char * batch_rep_pen_name(const llama_sampler * /*smpl*/)
{
return "kcpp-batch-rep-pen";
}
static void batch_rep_pen_accept(llama_sampler * smpl, llama_token token)
{
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
if(ctx->penalty_last_n <= 0)
{
return;
}
if(ctx->prev.size() >= (size_t) ctx->penalty_last_n)
{
ctx->prev.erase(ctx->prev.begin());
}
ctx->prev.push_back(token);
}
static void batch_rep_pen_apply(llama_sampler * smpl, llama_token_data_array * cur_p)
{
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
int last_n_repeat = std::min((int) ctx->prev.size(), ctx->penalty_last_n);
if(last_n_repeat <= 0 || (ctx->penalty_repeat == 1.0f && ctx->penalty_present == 0.0f))
{
return;
}
const llama_token * last_tokens = ctx->prev.data() + ctx->prev.size() - last_n_repeat;
std::unordered_set<llama_token> tokens_near(last_tokens + last_n_repeat / 2, last_tokens + last_n_repeat);
std::unordered_set<llama_token> tokens_far(last_tokens, last_tokens + last_n_repeat / 2);
float penalty_reduced = ctx->penalty_repeat;
if(penalty_reduced > 1.0f)
{
penalty_reduced = 1.0f + ((ctx->penalty_repeat - 1.0f) * ctx->penalty_slope);
}
for(size_t i = 0; i < cur_p->size; ++i)
{
const bool token_in_near = tokens_near.find(cur_p->data[i].id) != tokens_near.end();
const bool token_in_far = tokens_far.find(cur_p->data[i].id) != tokens_far.end();
if(!token_in_near && !token_in_far)
{
continue;
}
float penalty = token_in_near ? ctx->penalty_repeat : penalty_reduced;
if(cur_p->data[i].logit <= 0)
{
cur_p->data[i].logit *= penalty;
}
else
{
cur_p->data[i].logit /= penalty;
}
cur_p->data[i].logit -= ctx->penalty_present;
}
cur_p->sorted = false;
}
static void batch_rep_pen_reset(llama_sampler * smpl)
{
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
ctx->prev.clear();
}
static llama_sampler * batch_rep_pen_clone(const llama_sampler * smpl)
{
const auto * ctx = (const BatchRepPenSampler *) smpl->ctx;
auto * result = llama_sampler_init(smpl->iface, new BatchRepPenSampler {
ctx->penalty_last_n,
ctx->penalty_repeat,
ctx->penalty_slope,
ctx->penalty_present,
ctx->prev,
});
return result;
}
static void batch_rep_pen_free(llama_sampler * smpl)
{
delete (BatchRepPenSampler *) smpl->ctx;
}
static llama_sampler_i batch_rep_pen_i = {
/* .name = */ batch_rep_pen_name,
/* .accept = */ batch_rep_pen_accept,
/* .apply = */ batch_rep_pen_apply,
/* .reset = */ batch_rep_pen_reset,
/* .clone = */ batch_rep_pen_clone,
/* .free = */ batch_rep_pen_free,
/* .backend_init = */ nullptr,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ nullptr,
/* .backend_set_input = */ nullptr,
};
static llama_sampler * batch_rep_pen_init(int32_t penalty_last_n, float penalty_repeat, float penalty_slope, float penalty_present)
{
penalty_last_n = std::max(penalty_last_n, 0);
if(penalty_slope <= 0.0f || penalty_slope > 1.0f)
{
penalty_slope = 1.0f;
}
return llama_sampler_init(&batch_rep_pen_i, new BatchRepPenSampler {
penalty_last_n,
penalty_repeat <= 0.0f ? 1.0f : penalty_repeat,
penalty_slope,
penalty_present,
{},
});
}
static llama_sampler * batch_build_sampler(const BatchGenerateRequest & req)
{
llama_sampler_chain_params params = llama_sampler_chain_default_params();
llama_sampler * chain = llama_sampler_chain_init(params);
llama_sampler_chain_add(chain, batch_rep_pen_init(
req.rep_pen_range,
req.rep_pen,
req.rep_pen_slope,
req.presence_penalty));
if(req.top_k > 0)
{
llama_sampler_chain_add(chain, llama_sampler_init_top_k(req.top_k));
}
if(req.top_p > 0.0f && req.top_p < 1.0f)
{
llama_sampler_chain_add(chain, llama_sampler_init_top_p(req.top_p, 1));
}
if(req.min_p > 0.0f)
{
llama_sampler_chain_add(chain, llama_sampler_init_min_p(req.min_p, 1));
}
if(req.typical_p > 0.0f && req.typical_p < 1.0f)
{
llama_sampler_chain_add(chain, llama_sampler_init_typical(req.typical_p, 1));
}
if(req.temperature > 0.0f)
{
llama_sampler_chain_add(chain, llama_sampler_init_temp(req.temperature));
llama_sampler_chain_add(chain, llama_sampler_init_dist(req.seed < 0 ? LLAMA_DEFAULT_SEED : (uint32_t) req.seed));
}
else
{
llama_sampler_chain_add(chain, llama_sampler_init_greedy());
}
return chain;
}
static void batch_finish_request_locked(BatchGenerateRequest & req, stop_reason reason)
{
auto finish_time = std::chrono::steady_clock::now();
float total_time = req.start_time.time_since_epoch().count() == 0 ? 0.0f : std::chrono::duration<float>(finish_time - req.start_time).count();
float generated_tps = total_time > 0.0f ? (float) req.completion_token_count / total_time : 0.0f;
req.finish_reason = reason;
req.result.status = (reason == stop_reason::ERROR_ENCOUNTERED) ? 0 : 1;
req.result.stopreason = reason;
req.result.prompt_tokens = req.prompt_token_count;
req.result.completion_tokens = req.completion_token_count;
req.result.text = req.output.c_str();
req.state = reason == stop_reason::ERROR_ENCOUNTERED ? BatchState::FAILED : (reason == stop_reason::INVALID ? BatchState::ABORTED : BatchState::FINISHED);
if(req.slot >= 0 && llama_ctx_v4)
{
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), req.slot, -1, -1);
}
req.slot = -1;
printf("\n[%s] BatchRequest:%d, Prompt:%d, Generated:%d/%d in %.2fs (%.2fT/s), Stop:%d",
get_timestamp_str().c_str(), req.id, req.prompt_token_count, req.completion_token_count, req.max_length, total_time, generated_tps, (int) reason);
fflush(stdout);
batch_cv.notify_all();
}
static bool batch_output_hit_stop(const BatchGenerateRequest & req)
{
for(const auto & stopper : req.stop_sequences)
{
if(!stopper.empty() && req.output.find(stopper) != std::string::npos)
{
return true;
}
}
return false;
}
static bool batch_claim_waiting_locked()
{
bool claimed = false;
for(int slot = 1; slot <= continuous_batching_slots && !batch_waiting.empty(); ++slot)
{
bool occupied = false;
for(const auto & req : batch_requests)
{
if(req && req->slot == slot && batch_is_live_state(req->state))
{
occupied = true;
break;
}
}
if(occupied)
{
continue;
}
int request_id = batch_waiting.front();
batch_waiting.pop_front();
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req || req->state != BatchState::WAITING)
{
continue;
}
req->slot = slot;
req->state = BatchState::PREFILL;
batch_touched_since_legacy = true;
TokenizeString(req->prompt, req->prompt_tokens, file_format, add_bos_token);
if(req->prompt_tokens.empty())
{
TokenizeString("", req->prompt_tokens, file_format, add_bos_token);
}
int n_ctx = req->max_context_length > 0 ? std::min(req->max_context_length, kcpp_data->n_ctx) : kcpp_data->n_ctx;
if(req->max_length > 0 && (int) req->prompt_tokens.size() + req->max_length > n_ctx)
{
int keep = std::max(1, n_ctx - req->max_length);
if((int) req->prompt_tokens.size() > keep)
{
req->prompt_tokens.erase(req->prompt_tokens.begin(), req->prompt_tokens.end() - keep);
}
}
req->prompt_token_count = req->prompt_tokens.size();
req->sampler = batch_build_sampler(*req);
for(llama_token token : req->prompt_tokens)
{
llama_sampler_accept(req->sampler, token);
}
req->prompt_pos = 0;
req->n_past = 0;
req->has_pending = false;
req->i_batch = -1;
req->start_time = std::chrono::steady_clock::now();
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), slot, -1, -1);
claimed = true;
}
return claimed;
}
static void batch_worker_loop()
{
const int batch_cap = std::max(1, kcpp_data ? kcpp_data->n_batch : 512);
llama_batch batch = llama_batch_init(batch_cap, 0, 1);
while(true)
{
std::vector<int> decode_ids;
{
std::unique_lock<std::mutex> lock(batch_mutex);
batch_cv.wait_for(lock, std::chrono::milliseconds(5), [](){
return batch_worker_stop || (!batch_legacy_active && batch_has_live_locked());
});
if(batch_worker_stop)
{
break;
}
if(batch_legacy_active)
{
continue;
}
batch_claim_waiting_locked();
common_batch_clear(batch);
for(auto & req_ptr : batch_requests)
{
if(!req_ptr || !batch_is_live_state(req_ptr->state) || req_ptr->slot < 0 || batch.n_tokens >= batch_cap)
{
continue;
}
BatchGenerateRequest & req = *req_ptr;
req.i_batch = -1;
if(req.abort_requested)
{
batch_finish_request_locked(req, stop_reason::INVALID);
continue;
}
if(req.state == BatchState::PREFILL)
{
while(req.prompt_pos < (int) req.prompt_tokens.size() && batch.n_tokens < batch_cap)
{
bool is_last = req.prompt_pos == (int) req.prompt_tokens.size() - 1;
if(is_last)
{
req.i_batch = batch.n_tokens;
}
common_batch_add(batch, req.prompt_tokens[req.prompt_pos], req.n_past, { req.slot }, is_last);
req.prompt_pos++;
req.n_past++;
}
if(req.prompt_pos == (int) req.prompt_tokens.size())
{
req.state = BatchState::GENERATING;
}
}
else if(req.state == BatchState::GENERATING && req.has_pending)
{
req.i_batch = batch.n_tokens;
common_batch_add(batch, req.pending_token, req.n_past, { req.slot }, true);
req.n_past++;
req.has_pending = false;
}
}
if(batch.n_tokens == 0)
{
continue;
}
for(auto & req_ptr : batch_requests)
{
if(req_ptr && req_ptr->i_batch >= 0)
{
decode_ids.push_back(req_ptr->id);
}
}
}
int decode_status = llama_decode(llama_ctx_v4, batch);
std::lock_guard<std::mutex> lock(batch_mutex);
if(decode_status != 0)
{
for(int request_id : decode_ids)
{
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(req && batch_is_live_state(req->state))
{
batch_finish_request_locked(*req, stop_reason::ERROR_ENCOUNTERED);
}
}
continue;
}
const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
llama_token eos = llama_vocab_eos(vocab);
for(int request_id : decode_ids)
{
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req || req->state != BatchState::GENERATING || req->i_batch < 0)
{
continue;
}
llama_token sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
if(!req->allow_eos_token && !req->bypass_eos_token && sampled == eos)
{
sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
}
req->completion_token_count++;
if(sampled == eos && !req->bypass_eos_token)
{
batch_finish_request_locked(*req, stop_reason::EOS_TOKEN_HIT);
continue;
}
std::string piece = FileFormatTokenizeID(sampled, file_format, req->render_special);
req->generated_pieces.push_back(piece);
req->output += piece;
if(batch_output_hit_stop(*req))
{
batch_finish_request_locked(*req, stop_reason::CUSTOM_STOPPER);
continue;
}
if(req->max_length > 0 && req->completion_token_count >= req->max_length)
{
batch_finish_request_locked(*req, stop_reason::OUT_OF_TOKENS);
continue;
}
req->pending_token = sampled;
req->has_pending = true;
req->i_batch = -1;
}
}
llama_batch_free(batch);
}
static void batch_start_worker_locked()
{
if(batch_worker_started)
{
return;
}
batch_worker_stop = false;
batch_worker_thread = std::thread(batch_worker_loop);
batch_worker_thread.detach();
batch_worker_started = true;
}
bool gpttype_batch_generate_enabled()
{
return continuous_batching_slots > 1 && file_format == FileFormat::GGUF_GENERIC && llama_ctx_v4 && kcpp_data;
}
int gpttype_batch_generate_submit(const generation_inputs inputs)
{
if(!batch_inputs_eligible(inputs))
{
return -1;
}
std::lock_guard<std::mutex> lock(batch_mutex);
if(batch_legacy_active || batch_legacy_waiting > 0)
{
return -1;
}
auto req = std::make_unique<BatchGenerateRequest>();
req->id = batch_next_request_id++;
req->prompt = inputs.prompt ? inputs.prompt : "";
req->max_context_length = inputs.max_context_length;
req->max_length = inputs.max_length;
req->seed = inputs.seed;
req->temperature = inputs.temperature;
req->top_k = inputs.top_k;
req->top_p = inputs.top_p;
req->min_p = inputs.min_p;
req->typical_p = inputs.typical_p;
req->rep_pen = inputs.rep_pen;
req->rep_pen_slope = inputs.rep_pen_slope;
req->rep_pen_range = inputs.rep_pen_range;
req->presence_penalty = inputs.presence_penalty;
req->allow_eos_token = inputs.allow_eos_token;
req->bypass_eos_token = inputs.bypass_eos_token;
req->render_special = inputs.render_special;
for(int i = 0; i < inputs.stop_sequence_len; ++i)
{
if(inputs.stop_sequence[i])
{
req->stop_sequences.emplace_back(inputs.stop_sequence[i]);
}
}
int request_id = req->id;
batch_requests.emplace_back(std::move(req));
batch_waiting.push_back(request_id);
batch_start_worker_locked();
batch_cv.notify_all();
return request_id;
}
bool gpttype_batch_generate_has_finished(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
return !req || !batch_is_live_state(req->state);
}
int gpttype_batch_generate_stream_count(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
return req ? req->generated_pieces.size() : 0;
}
const char * gpttype_batch_generate_new_token(int request_id, int idx)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req || idx < 0 || idx >= (int) req->generated_pieces.size())
{
return nullptr;
}
return req->generated_pieces[idx].c_str();
}
const char * gpttype_batch_generate_pending_output(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req)
{
return batch_empty_string.c_str();
}
return req->output.c_str();
}
generation_outputs gpttype_batch_generate_result(int request_id)
{
std::unique_lock<std::mutex> lock(batch_mutex);
batch_cv.wait(lock, [request_id](){
BatchGenerateRequest * req = batch_find_request_locked(request_id);
return !req || !batch_is_live_state(req->state);
});
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req)
{
generation_outputs output;
output.status = 0;
output.stopreason = stop_reason::ERROR_ENCOUNTERED;
output.prompt_tokens = 0;
output.completion_tokens = 0;
output.text = batch_empty_string.c_str();
return output;
}
req->result.text = req->output.c_str();
return req->result;
}
bool gpttype_batch_generate_abort(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req)
{
return false;
}
req->abort_requested = true;
batch_cv.notify_all();
return true;
}
void gpttype_batch_generate_release(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
batch_requests.erase(std::remove_if(batch_requests.begin(), batch_requests.end(), [request_id](const std::unique_ptr<BatchGenerateRequest> & req){
return req && req->id == request_id && !batch_is_live_state(req->state);
}), batch_requests.end());
batch_cv.notify_all();
}
std::string gpttype_get_chat_template()
{
if(kcpp_data==nullptr)
@ -3547,6 +4301,7 @@ int smartcache_quick_snapshot(int specific_slot = -1)
generation_outputs gpttype_generate(const generation_inputs inputs)
{
BatchLegacyGuard batch_legacy_guard;
generation_outputs output;
if(kcpp_data==nullptr)