feat: add a primitive form of continuous batching (#2167)

* feat: add a primitive form of continuous batching

* fix: deadlock in batching fallback

* fix: windows build

* chore: suppress the contbatch arg from --help

* feat: batch-aware rep_pen_slope

* fix: automatically disable shifting when batching is enabled

* fix: mixed-path state corruption

* fix: attempt to fully separate the two pipelines

* added a semaphore to prevent non-batchable requests from starting while batched requests are running

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
Author: AlpinDale, 2026-05-10 14:20:31 +04:30 (committed by GitHub)
Parent: a47037637c
Commit: c03302b670
5 changed files with 924 additions and 16 deletions


@ -258,6 +258,33 @@ extern "C"
bool has_finished() {
return generation_finished;
}
bool batch_generate_enabled() {
return gpttype_batch_generate_enabled();
}
int batch_generate_submit(const generation_inputs inputs) {
return gpttype_batch_generate_submit(inputs);
}
bool batch_generate_has_finished(int request_id) {
return gpttype_batch_generate_has_finished(request_id);
}
int batch_generate_stream_count(int request_id) {
return gpttype_batch_generate_stream_count(request_id);
}
const char * batch_generate_new_token(int request_id, int idx) {
return gpttype_batch_generate_new_token(request_id, idx);
}
const char * batch_generate_pending_output(int request_id) {
return gpttype_batch_generate_pending_output(request_id);
}
generation_outputs batch_generate_result(int request_id) {
return gpttype_batch_generate_result(request_id);
}
bool batch_generate_abort(int request_id) {
return gpttype_batch_generate_abort(request_id);
}
void batch_generate_release(int request_id) {
gpttype_batch_generate_release(request_id);
}
bool has_audio_support()
{
return audio_multimodal_supported;
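The bridge above is a simple polling API: submit a request, poll for completion, read streamed pieces by index, fetch the final result, then release the per-request state. As a rough illustration (not part of this commit), this is how the matching ctypes bindings declared further down in this diff could be driven from Python; the run_batched helper and the polling interval are invented for the example.

import time

def run_batched(handle, inputs):
    # Fall back to the classic handle.generate(inputs) path when batching is unavailable.
    if not handle.batch_generate_enabled():
        return None
    req_id = handle.batch_generate_submit(inputs)   # returns -1 if the request is not eligible
    if req_id < 0:
        return None
    seen = 0
    while not handle.batch_generate_has_finished(req_id):
        count = handle.batch_generate_stream_count(req_id)
        while seen < count:
            piece = handle.batch_generate_new_token(req_id, seen)  # bytes, or None if not ready
            if piece is None:
                break
            print(piece.decode("utf-8", "ignore"), end="", flush=True)
            seen += 1
        time.sleep(0.02)
    result = handle.batch_generate_result(req_id)   # blocks until the request finishes
    handle.batch_generate_release(req_id)           # frees the per-request bookkeeping
    return result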


@ -84,6 +84,7 @@ struct load_model_inputs
const char * devices_override = nullptr;
const bool quiet = false;
const int debugmode = 0;
const int continuous_batching_slots = 0;
};
struct generation_inputs
{
@ -385,3 +386,13 @@ extern int total_transcribe_gens;
extern int last_draft_success;
extern int last_draft_failed;
extern stop_reason last_stop_reason;
bool gpttype_batch_generate_enabled();
int gpttype_batch_generate_submit(const generation_inputs inputs);
bool gpttype_batch_generate_has_finished(int request_id);
int gpttype_batch_generate_stream_count(int request_id);
const char * gpttype_batch_generate_new_token(int request_id, int idx);
const char * gpttype_batch_generate_pending_output(int request_id);
generation_outputs gpttype_batch_generate_result(int request_id);
bool gpttype_batch_generate_abort(int request_id);
void gpttype_batch_generate_release(int request_id);


@ -22,6 +22,11 @@
#include <cctype>
#include <locale>
#include <chrono>
#include <algorithm>
#include <condition_variable>
#include <deque>
#include <memory>
#include <thread>
#include "utils.h"
#include "llmutils.h"
@ -76,6 +81,7 @@ int last_draft_success = 0;
int last_draft_failed = 0;
stop_reason last_stop_reason = stop_reason::INVALID;
std::vector<std::string> generated_tokens;
static int continuous_batching_slots = 0;
llama_grammar * grammar = nullptr; //currently used grammar
llama_grammar_parser parsed_grammar;
@ -2197,6 +2203,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
kcpp_pipeline_parallelism = inputs.pipelineparallel;
kcpp_data->n_batch = GetBatchSize(inputs.batchsize, in_file_format);
kcpp_data->n_ubatch = kcpp_data->n_batch;
continuous_batching_slots = (isGguf && inputs.continuous_batching_slots > 1) ? inputs.continuous_batching_slots : 0;
if(continuous_batching_slots > 0)
{
printf("Continuous batching: prepared %d GGUF sequence slots.\n", continuous_batching_slots);
}
kcpp_data->vision_min_tokens = inputs.visionmintokens;
kcpp_data->vision_max_tokens = inputs.visionmaxtokens;
if(isGguf && kcpp_pipeline_parallelism)
@ -2461,6 +2472,10 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
llama_ctx_params.n_batch = kcpp_data->n_batch;
llama_ctx_params.n_ubatch = kcpp_data->n_ubatch;
if(continuous_batching_slots > 0)
{
llama_ctx_params.n_seq_max = continuous_batching_slots + 1;
}
llama_ctx_params.n_threads = kcpp_data->n_threads;
llama_ctx_params.n_threads_batch = kcpp_data->n_blasthreads;
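The extra sequence appears to reserve KV-cache sequence 0 for the legacy single-request pipeline (which is what batch_invalidate_legacy_context_locked clears below), while batched requests claim sequences 1 through continuous_batching_slots. A tiny illustrative sketch of that layout, assuming --contbatch 4 (not project code):

continuous_batching_slots = 4                                  # e.g. --contbatch 4
n_seq_max = continuous_batching_slots + 1                      # 5 KV-cache sequences in total
legacy_seq = 0                                                 # classic gpttype_generate path
batch_slots = list(range(1, continuous_batching_slots + 1))    # [1, 2, 3, 4] for batched requests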
@ -3269,6 +3284,745 @@ bool gpttype_generate_abort()
return true;
}
enum class BatchState
{
WAITING,
PREFILL,
GENERATING,
FINISHED,
FAILED,
ABORTED,
};
struct BatchGenerateRequest
{
int id = 0;
int slot = -1;
BatchState state = BatchState::WAITING;
std::string prompt;
std::vector<std::string> stop_sequences;
int max_context_length = 0;
int max_length = 0;
int seed = 0;
float temperature = 0.0f;
int top_k = 0;
float top_p = 1.0f;
float min_p = 0.0f;
float typical_p = 1.0f;
float rep_pen = 1.0f;
float rep_pen_slope = 1.0f;
int rep_pen_range = 0;
float presence_penalty = 0.0f;
bool allow_eos_token = true;
bool bypass_eos_token = false;
bool render_special = false;
std::vector<llama_token> prompt_tokens;
int prompt_pos = 0;
int n_past = 0;
bool has_pending = false;
llama_token pending_token = 0;
int i_batch = -1;
llama_sampler * sampler = nullptr;
std::vector<std::string> generated_pieces;
std::string output;
int prompt_token_count = 0;
int completion_token_count = 0;
std::chrono::steady_clock::time_point start_time;
stop_reason finish_reason = stop_reason::INVALID;
bool abort_requested = false;
generation_outputs result;
~BatchGenerateRequest()
{
if(sampler)
{
llama_sampler_free(sampler);
sampler = nullptr;
}
}
};
static std::mutex batch_mutex;
static std::condition_variable batch_cv;
static std::deque<int> batch_waiting;
static std::vector<std::unique_ptr<BatchGenerateRequest>> batch_requests;
static std::thread batch_worker_thread;
static bool batch_worker_stop = false;
static bool batch_worker_started = false;
static bool batch_legacy_active = false;
static bool batch_touched_since_legacy = false;
static int batch_legacy_waiting = 0;
static int batch_next_request_id = 1;
static std::string batch_empty_string = "";
static BatchGenerateRequest * batch_find_request_locked(int request_id)
{
for(auto & req : batch_requests)
{
if(req && req->id == request_id)
{
return req.get();
}
}
return nullptr;
}
static bool batch_is_live_state(BatchState state)
{
return state == BatchState::WAITING || state == BatchState::PREFILL || state == BatchState::GENERATING;
}
static bool batch_has_live_locked()
{
for(const auto & req : batch_requests)
{
if(req && batch_is_live_state(req->state))
{
return true;
}
}
return false;
}
static void batch_invalidate_legacy_context_locked()
{
if(!batch_touched_since_legacy)
{
return;
}
batch_touched_since_legacy = false;
n_past = 0;
current_context_tokens.clear();
last_n_tokens.clear();
smartcontext.clear();
loaded_latest_logits.clear();
if(llama_ctx_v4)
{
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), 0, -1, -1);
}
if(draft_ctx)
{
llama_memory_seq_rm(llama_get_memory(draft_ctx), 0, -1, -1);
}
if(debugmode==1 && !is_quiet)
{
printf("\n[Continuous batching touched shared context; forcing next legacy generation to reprocess prompt]\n");
}
}
class BatchLegacyGuard
{
public:
BatchLegacyGuard()
{
std::unique_lock<std::mutex> lock(batch_mutex);
batch_legacy_waiting++;
batch_cv.notify_all();
batch_cv.wait(lock, [](){ return !batch_has_live_locked(); });
batch_legacy_waiting--;
batch_invalidate_legacy_context_locked();
batch_legacy_active = true;
}
~BatchLegacyGuard()
{
std::lock_guard<std::mutex> lock(batch_mutex);
batch_legacy_active = false;
batch_cv.notify_all();
}
};
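BatchLegacyGuard is the mutual-exclusion half of the mixed-path fix: a legacy gpttype_generate call waits until no batched request is live, invalidates the shared sequence-0 state if batching touched it, and holds the batch worker off until it finishes. A rough Python analogue of the same condition-variable pattern (names invented for illustration; the server-side counterpart in this diff uses batched_cond and batched_request_runner_count):

import threading

cond = threading.Condition()
live_batched = 0        # number of live batched requests
legacy_active = False   # True while the legacy path owns the model

class LegacyGuard:
    """Rough analogue of BatchLegacyGuard: drain batched work, then hold it off."""
    def __enter__(self):
        global legacy_active
        with cond:
            cond.wait_for(lambda: live_batched == 0)   # wait for batched requests to finish
            legacy_active = True                       # batch worker now skips its passes
        return self

    def __exit__(self, *exc):
        global legacy_active
        with cond:
            legacy_active = False
            cond.notify_all()                          # wake the batch worker again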
static bool batch_inputs_eligible(const generation_inputs & inputs)
{
if(continuous_batching_slots <= 1 || file_format != FileFormat::GGUF_GENERIC || !llama_ctx_v4 || !kcpp_data)
{
return false;
}
if(draft_ctx || guidance_ctx || clp_ctx_v || clp_ctx_a)
{
return false;
}
if(kcpp_data->use_smartcontext || kcpp_data->use_contextshift || kcpp_data->smartcache)
{
return false;
}
if(inputs.memory && std::string(inputs.memory).size() > 0)
{
return false;
}
if(inputs.negative_prompt && std::string(inputs.negative_prompt).size() > 0)
{
return false;
}
if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f || !inputs.allow_eos_token)
{
return false;
}
if(inputs.grammar && std::string(inputs.grammar).size() > 0)
{
return false;
}
if(inputs.logit_biases_len > 0 || inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
{
return false;
}
if(inputs.mirostat != 0 || inputs.xtc_probability > 0.0f || inputs.nsigma > 0.0f || inputs.smoothing_factor > 0.0f || inputs.adaptive_target > 0.0f)
{
return false;
}
if(inputs.top_a > 0.0f || inputs.tfs != 1.0f || inputs.dynatemp_range > 0.0f)
{
return false;
}
static const int default_sampler_order[] = {6, 0, 1, 3, 4, 2, 5};
if(inputs.sampler_len > 0)
{
if(inputs.sampler_len != 7)
{
return false;
}
for(int i = 0; i < 7; ++i)
{
if((int) inputs.sampler_order[i] != default_sampler_order[i])
{
return false;
}
}
}
if(inputs.reasoning_budget >= 0 || inputs.tool_call_fix)
{
return false;
}
return true;
}
struct BatchRepPenSampler
{
int32_t penalty_last_n = 0;
float penalty_repeat = 1.0f;
float penalty_slope = 1.0f;
float penalty_present = 0.0f;
std::vector<llama_token> prev;
};
static const char * batch_rep_pen_name(const llama_sampler * /*smpl*/)
{
return "kcpp-batch-rep-pen";
}
static void batch_rep_pen_accept(llama_sampler * smpl, llama_token token)
{
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
if(ctx->penalty_last_n <= 0)
{
return;
}
if(ctx->prev.size() >= (size_t) ctx->penalty_last_n)
{
ctx->prev.erase(ctx->prev.begin());
}
ctx->prev.push_back(token);
}
static void batch_rep_pen_apply(llama_sampler * smpl, llama_token_data_array * cur_p)
{
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
int last_n_repeat = std::min((int) ctx->prev.size(), ctx->penalty_last_n);
if(last_n_repeat <= 0 || (ctx->penalty_repeat == 1.0f && ctx->penalty_present == 0.0f))
{
return;
}
const llama_token * last_tokens = ctx->prev.data() + ctx->prev.size() - last_n_repeat;
std::unordered_set<llama_token> tokens_near(last_tokens + last_n_repeat / 2, last_tokens + last_n_repeat);
std::unordered_set<llama_token> tokens_far(last_tokens, last_tokens + last_n_repeat / 2);
float penalty_reduced = ctx->penalty_repeat;
if(penalty_reduced > 1.0f)
{
penalty_reduced = 1.0f + ((ctx->penalty_repeat - 1.0f) * ctx->penalty_slope);
}
for(size_t i = 0; i < cur_p->size; ++i)
{
const bool token_in_near = tokens_near.find(cur_p->data[i].id) != tokens_near.end();
const bool token_in_far = tokens_far.find(cur_p->data[i].id) != tokens_far.end();
if(!token_in_near && !token_in_far)
{
continue;
}
float penalty = token_in_near ? ctx->penalty_repeat : penalty_reduced;
if(cur_p->data[i].logit <= 0)
{
cur_p->data[i].logit *= penalty;
}
else
{
cur_p->data[i].logit /= penalty;
}
cur_p->data[i].logit -= ctx->penalty_present;
}
cur_p->sorted = false;
}
static void batch_rep_pen_reset(llama_sampler * smpl)
{
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
ctx->prev.clear();
}
static llama_sampler * batch_rep_pen_clone(const llama_sampler * smpl)
{
const auto * ctx = (const BatchRepPenSampler *) smpl->ctx;
auto * result = llama_sampler_init(smpl->iface, new BatchRepPenSampler {
ctx->penalty_last_n,
ctx->penalty_repeat,
ctx->penalty_slope,
ctx->penalty_present,
ctx->prev,
});
return result;
}
static void batch_rep_pen_free(llama_sampler * smpl)
{
delete (BatchRepPenSampler *) smpl->ctx;
}
static llama_sampler_i batch_rep_pen_i = {
/* .name = */ batch_rep_pen_name,
/* .accept = */ batch_rep_pen_accept,
/* .apply = */ batch_rep_pen_apply,
/* .reset = */ batch_rep_pen_reset,
/* .clone = */ batch_rep_pen_clone,
/* .free = */ batch_rep_pen_free,
/* .backend_init = */ nullptr,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ nullptr,
/* .backend_set_input = */ nullptr,
};
static llama_sampler * batch_rep_pen_init(int32_t penalty_last_n, float penalty_repeat, float penalty_slope, float penalty_present)
{
penalty_last_n = std::max(penalty_last_n, 0);
if(penalty_slope <= 0.0f || penalty_slope > 1.0f)
{
penalty_slope = 1.0f;
}
return llama_sampler_init(&batch_rep_pen_i, new BatchRepPenSampler {
penalty_last_n,
penalty_repeat <= 0.0f ? 1.0f : penalty_repeat,
penalty_slope,
penalty_present,
{},
});
}
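The custom sampler implements the batch-aware rep_pen_slope from the commit message: the most recent half of the penalty window gets the full repetition penalty, the older half gets a slope-reduced penalty, and both halves also pay the presence penalty (a token present in both halves counts as recent). A scalar sketch of that adjustment in Python, purely for illustration:

def penalized_logit(logit, in_recent_half, in_older_half,
                    rep_pen=1.1, rep_pen_slope=0.5, presence_penalty=0.0):
    """Mirror of batch_rep_pen_apply for a single candidate token."""
    if not (in_recent_half or in_older_half):
        return logit
    reduced = rep_pen
    if reduced > 1.0:
        reduced = 1.0 + (rep_pen - 1.0) * rep_pen_slope   # slope softens the older-half penalty
    penalty = rep_pen if in_recent_half else reduced
    logit = logit * penalty if logit <= 0 else logit / penalty
    return logit - presence_penalty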
static llama_sampler * batch_build_sampler(const BatchGenerateRequest & req)
{
llama_sampler_chain_params params = llama_sampler_chain_default_params();
llama_sampler * chain = llama_sampler_chain_init(params);
llama_sampler_chain_add(chain, batch_rep_pen_init(
req.rep_pen_range,
req.rep_pen,
req.rep_pen_slope,
req.presence_penalty));
if(req.top_k > 0)
{
llama_sampler_chain_add(chain, llama_sampler_init_top_k(req.top_k));
}
if(req.top_p > 0.0f && req.top_p < 1.0f)
{
llama_sampler_chain_add(chain, llama_sampler_init_top_p(req.top_p, 1));
}
if(req.min_p > 0.0f)
{
llama_sampler_chain_add(chain, llama_sampler_init_min_p(req.min_p, 1));
}
if(req.typical_p > 0.0f && req.typical_p < 1.0f)
{
llama_sampler_chain_add(chain, llama_sampler_init_typical(req.typical_p, 1));
}
if(req.temperature > 0.0f)
{
llama_sampler_chain_add(chain, llama_sampler_init_temp(req.temperature));
llama_sampler_chain_add(chain, llama_sampler_init_dist(req.seed < 0 ? LLAMA_DEFAULT_SEED : (uint32_t) req.seed));
}
else
{
llama_sampler_chain_add(chain, llama_sampler_init_greedy());
}
return chain;
}
static void batch_finish_request_locked(BatchGenerateRequest & req, stop_reason reason)
{
auto finish_time = std::chrono::steady_clock::now();
float total_time = req.start_time.time_since_epoch().count() == 0 ? 0.0f : std::chrono::duration<float>(finish_time - req.start_time).count();
float generated_tps = total_time > 0.0f ? (float) req.completion_token_count / total_time : 0.0f;
req.finish_reason = reason;
req.result.status = (reason == stop_reason::ERROR_ENCOUNTERED) ? 0 : 1;
req.result.stopreason = reason;
req.result.prompt_tokens = req.prompt_token_count;
req.result.completion_tokens = req.completion_token_count;
req.result.text = req.output.c_str();
req.state = reason == stop_reason::ERROR_ENCOUNTERED ? BatchState::FAILED : (reason == stop_reason::INVALID ? BatchState::ABORTED : BatchState::FINISHED);
if(req.slot >= 0 && llama_ctx_v4)
{
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), req.slot, -1, -1);
}
req.slot = -1;
printf("\n[%s] BatchRequest:%d, Prompt:%d, Generated:%d/%d in %.2fs (%.2fT/s), Stop:%d",
get_timestamp_str().c_str(), req.id, req.prompt_token_count, req.completion_token_count, req.max_length, total_time, generated_tps, (int) reason);
fflush(stdout);
batch_cv.notify_all();
}
static bool batch_output_hit_stop(const BatchGenerateRequest & req)
{
for(const auto & stopper : req.stop_sequences)
{
if(!stopper.empty() && req.output.find(stopper) != std::string::npos)
{
return true;
}
}
return false;
}
static bool batch_claim_waiting_locked()
{
bool claimed = false;
for(int slot = 1; slot <= continuous_batching_slots && !batch_waiting.empty(); ++slot)
{
bool occupied = false;
for(const auto & req : batch_requests)
{
if(req && req->slot == slot && batch_is_live_state(req->state))
{
occupied = true;
break;
}
}
if(occupied)
{
continue;
}
int request_id = batch_waiting.front();
batch_waiting.pop_front();
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req || req->state != BatchState::WAITING)
{
continue;
}
req->slot = slot;
req->state = BatchState::PREFILL;
batch_touched_since_legacy = true;
TokenizeString(req->prompt, req->prompt_tokens, file_format, add_bos_token);
if(req->prompt_tokens.empty())
{
TokenizeString("", req->prompt_tokens, file_format, add_bos_token);
}
int n_ctx = req->max_context_length > 0 ? std::min(req->max_context_length, kcpp_data->n_ctx) : kcpp_data->n_ctx;
if(req->max_length > 0 && (int) req->prompt_tokens.size() + req->max_length > n_ctx)
{
int keep = std::max(1, n_ctx - req->max_length);
if((int) req->prompt_tokens.size() > keep)
{
req->prompt_tokens.erase(req->prompt_tokens.begin(), req->prompt_tokens.end() - keep);
}
}
req->prompt_token_count = req->prompt_tokens.size();
req->sampler = batch_build_sampler(*req);
for(llama_token token : req->prompt_tokens)
{
llama_sampler_accept(req->sampler, token);
}
req->prompt_pos = 0;
req->n_past = 0;
req->has_pending = false;
req->i_batch = -1;
req->start_time = std::chrono::steady_clock::now();
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), slot, -1, -1);
claimed = true;
}
return claimed;
}
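If a claimed prompt plus its generation budget would overflow the effective context window, the oldest prompt tokens are dropped so that at least one prompt token and the full max_length budget still fit. The same arithmetic in Python, for illustration only:

def trim_prompt(prompt_tokens, max_length, req_ctx, model_ctx):
    """Mirror of the prompt-trimming step in batch_claim_waiting_locked."""
    n_ctx = min(req_ctx, model_ctx) if req_ctx > 0 else model_ctx
    if max_length > 0 and len(prompt_tokens) + max_length > n_ctx:
        keep = max(1, n_ctx - max_length)        # newest tokens to keep
        if len(prompt_tokens) > keep:
            prompt_tokens = prompt_tokens[-keep:]
    return prompt_tokens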
static void batch_worker_loop()
{
const int batch_cap = std::max(1, kcpp_data ? kcpp_data->n_batch : 512);
llama_batch batch = llama_batch_init(batch_cap, 0, 1);
while(true)
{
std::vector<int> decode_ids;
{
std::unique_lock<std::mutex> lock(batch_mutex);
batch_cv.wait_for(lock, std::chrono::milliseconds(5), [](){
return batch_worker_stop || (!batch_legacy_active && batch_has_live_locked());
});
if(batch_worker_stop)
{
break;
}
if(batch_legacy_active)
{
continue;
}
batch_claim_waiting_locked();
common_batch_clear(batch);
for(auto & req_ptr : batch_requests)
{
if(!req_ptr || !batch_is_live_state(req_ptr->state) || req_ptr->slot < 0 || batch.n_tokens >= batch_cap)
{
continue;
}
BatchGenerateRequest & req = *req_ptr;
req.i_batch = -1;
if(req.abort_requested)
{
batch_finish_request_locked(req, stop_reason::INVALID);
continue;
}
if(req.state == BatchState::PREFILL)
{
while(req.prompt_pos < (int) req.prompt_tokens.size() && batch.n_tokens < batch_cap)
{
bool is_last = req.prompt_pos == (int) req.prompt_tokens.size() - 1;
if(is_last)
{
req.i_batch = batch.n_tokens;
}
common_batch_add(batch, req.prompt_tokens[req.prompt_pos], req.n_past, { req.slot }, is_last);
req.prompt_pos++;
req.n_past++;
}
if(req.prompt_pos == (int) req.prompt_tokens.size())
{
req.state = BatchState::GENERATING;
}
}
else if(req.state == BatchState::GENERATING && req.has_pending)
{
req.i_batch = batch.n_tokens;
common_batch_add(batch, req.pending_token, req.n_past, { req.slot }, true);
req.n_past++;
req.has_pending = false;
}
}
if(batch.n_tokens == 0)
{
continue;
}
for(auto & req_ptr : batch_requests)
{
if(req_ptr && req_ptr->i_batch >= 0)
{
decode_ids.push_back(req_ptr->id);
}
}
}
int decode_status = llama_decode(llama_ctx_v4, batch);
std::lock_guard<std::mutex> lock(batch_mutex);
if(decode_status != 0)
{
for(int request_id : decode_ids)
{
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(req && batch_is_live_state(req->state))
{
batch_finish_request_locked(*req, stop_reason::ERROR_ENCOUNTERED);
}
}
continue;
}
const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
llama_token eos = llama_vocab_eos(vocab);
for(int request_id : decode_ids)
{
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req || req->state != BatchState::GENERATING || req->i_batch < 0)
{
continue;
}
llama_token sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
if(!req->allow_eos_token && !req->bypass_eos_token && sampled == eos)
{
sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
}
req->completion_token_count++;
if(sampled == eos && !req->bypass_eos_token)
{
batch_finish_request_locked(*req, stop_reason::EOS_TOKEN_HIT);
continue;
}
std::string piece = FileFormatTokenizeID(sampled, file_format, req->render_special);
req->generated_pieces.push_back(piece);
req->output += piece;
if(batch_output_hit_stop(*req))
{
batch_finish_request_locked(*req, stop_reason::CUSTOM_STOPPER);
continue;
}
if(req->max_length > 0 && req->completion_token_count >= req->max_length)
{
batch_finish_request_locked(*req, stop_reason::OUT_OF_TOKENS);
continue;
}
req->pending_token = sampled;
req->has_pending = true;
req->i_batch = -1;
}
}
llama_batch_free(batch);
}
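Each worker pass packs at most n_batch tokens into a single llama_batch: requests still in prefill contribute a chunk of prompt tokens (only the final prompt token asks for logits), and each generating request contributes exactly one pending token; after one llama_decode, a new token is sampled per request that requested logits. The packing step is the essence of the continuous batching, so here is a compact runnable sketch of it in Python (the Req class and tuple batch format are stand-ins for the real llama_batch, invented for this example):

from dataclasses import dataclass

@dataclass
class Req:
    slot: int
    prompt_tokens: list
    prompt_pos: int = 0
    n_past: int = 0
    state: str = "PREFILL"
    has_pending: bool = False
    pending_token: int = -1

def pack_batch(reqs, batch_cap):
    """Pack one decode batch: prompt prefill chunks plus one token per generating request."""
    batch = []                                  # entries: (token, position, seq_id, want_logits)
    for r in reqs:
        if len(batch) >= batch_cap:
            break
        if r.state == "PREFILL":
            while r.prompt_pos < len(r.prompt_tokens) and len(batch) < batch_cap:
                last = r.prompt_pos == len(r.prompt_tokens) - 1
                batch.append((r.prompt_tokens[r.prompt_pos], r.n_past, r.slot, last))
                r.prompt_pos += 1
                r.n_past += 1
            if r.prompt_pos == len(r.prompt_tokens):
                r.state = "GENERATING"          # next pass feeds sampled tokens one at a time
        elif r.state == "GENERATING" and r.has_pending:
            batch.append((r.pending_token, r.n_past, r.slot, True))
            r.n_past += 1
            r.has_pending = False
    return batch

# Two requests sharing one 8-token decode batch:
reqs = [Req(slot=1, prompt_tokens=[1, 2, 3, 4, 5, 6]), Req(slot=2, prompt_tokens=[7, 8, 9])]
print(pack_batch(reqs, batch_cap=8))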
static void batch_start_worker_locked()
{
if(batch_worker_started)
{
return;
}
batch_worker_stop = false;
batch_worker_thread = std::thread(batch_worker_loop);
batch_worker_thread.detach();
batch_worker_started = true;
}
bool gpttype_batch_generate_enabled()
{
return continuous_batching_slots > 1 && file_format == FileFormat::GGUF_GENERIC && llama_ctx_v4 && kcpp_data;
}
int gpttype_batch_generate_submit(const generation_inputs inputs)
{
if(!batch_inputs_eligible(inputs))
{
return -1;
}
std::lock_guard<std::mutex> lock(batch_mutex);
if(batch_legacy_active || batch_legacy_waiting > 0)
{
return -1;
}
auto req = std::make_unique<BatchGenerateRequest>();
req->id = batch_next_request_id++;
req->prompt = inputs.prompt ? inputs.prompt : "";
req->max_context_length = inputs.max_context_length;
req->max_length = inputs.max_length;
req->seed = inputs.seed;
req->temperature = inputs.temperature;
req->top_k = inputs.top_k;
req->top_p = inputs.top_p;
req->min_p = inputs.min_p;
req->typical_p = inputs.typical_p;
req->rep_pen = inputs.rep_pen;
req->rep_pen_slope = inputs.rep_pen_slope;
req->rep_pen_range = inputs.rep_pen_range;
req->presence_penalty = inputs.presence_penalty;
req->allow_eos_token = inputs.allow_eos_token;
req->bypass_eos_token = inputs.bypass_eos_token;
req->render_special = inputs.render_special;
for(int i = 0; i < inputs.stop_sequence_len; ++i)
{
if(inputs.stop_sequence[i])
{
req->stop_sequences.emplace_back(inputs.stop_sequence[i]);
}
}
int request_id = req->id;
batch_requests.emplace_back(std::move(req));
batch_waiting.push_back(request_id);
batch_start_worker_locked();
batch_cv.notify_all();
return request_id;
}
bool gpttype_batch_generate_has_finished(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
return !req || !batch_is_live_state(req->state);
}
int gpttype_batch_generate_stream_count(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
return req ? req->generated_pieces.size() : 0;
}
const char * gpttype_batch_generate_new_token(int request_id, int idx)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req || idx < 0 || idx >= (int) req->generated_pieces.size())
{
return nullptr;
}
return req->generated_pieces[idx].c_str();
}
const char * gpttype_batch_generate_pending_output(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req)
{
return batch_empty_string.c_str();
}
return req->output.c_str();
}
generation_outputs gpttype_batch_generate_result(int request_id)
{
std::unique_lock<std::mutex> lock(batch_mutex);
batch_cv.wait(lock, [request_id](){
BatchGenerateRequest * req = batch_find_request_locked(request_id);
return !req || !batch_is_live_state(req->state);
});
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req)
{
generation_outputs output;
output.status = 0;
output.stopreason = stop_reason::ERROR_ENCOUNTERED;
output.prompt_tokens = 0;
output.completion_tokens = 0;
output.text = batch_empty_string.c_str();
return output;
}
req->result.text = req->output.c_str();
return req->result;
}
bool gpttype_batch_generate_abort(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
BatchGenerateRequest * req = batch_find_request_locked(request_id);
if(!req)
{
return false;
}
req->abort_requested = true;
batch_cv.notify_all();
return true;
}
void gpttype_batch_generate_release(int request_id)
{
std::lock_guard<std::mutex> lock(batch_mutex);
batch_requests.erase(std::remove_if(batch_requests.begin(), batch_requests.end(), [request_id](const std::unique_ptr<BatchGenerateRequest> & req){
return req && req->id == request_id && !batch_is_live_state(req->state);
}), batch_requests.end());
batch_cv.notify_all();
}
std::string gpttype_get_chat_template()
{
if(kcpp_data==nullptr)
@ -3547,6 +4301,7 @@ int smartcache_quick_snapshot(int specific_slot = -1)
generation_outputs gpttype_generate(const generation_inputs inputs)
{
BatchLegacyGuard batch_legacy_guard;
generation_outputs output;
if(kcpp_data==nullptr)


@ -110,6 +110,9 @@ maxctx = 8192
maxhordectx = 0 #set to whatever maxctx is if 0
maxhordelen = 1024
modelbusy = threading.Lock()
batched_lock = threading.Lock()
batched_cond = threading.Condition(batched_lock)
batched_request_runner_count = 0 #number of batched requests currently running; while nonzero, non-batched requests are held back
requestsinqueue = 0
ratelimitlookup = {}
defaultport = 5001
@ -284,7 +287,8 @@ class load_model_inputs(ctypes.Structure):
("lora_multiplier", ctypes.c_float),
("devices_override", ctypes.c_char_p),
("quiet", ctypes.c_bool),
("debugmode", ctypes.c_int)]
("debugmode", ctypes.c_int),
("continuous_batching_slots", ctypes.c_int)]
class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
@ -893,6 +897,23 @@ def init_library():
handle.new_token.argtypes = [ctypes.c_int]
handle.get_stream_count.restype = ctypes.c_int
handle.has_finished.restype = ctypes.c_bool
handle.batch_generate_enabled.restype = ctypes.c_bool
handle.batch_generate_submit.argtypes = [generation_inputs]
handle.batch_generate_submit.restype = ctypes.c_int
handle.batch_generate_has_finished.argtypes = [ctypes.c_int]
handle.batch_generate_has_finished.restype = ctypes.c_bool
handle.batch_generate_stream_count.argtypes = [ctypes.c_int]
handle.batch_generate_stream_count.restype = ctypes.c_int
handle.batch_generate_new_token.argtypes = [ctypes.c_int, ctypes.c_int]
handle.batch_generate_new_token.restype = ctypes.c_char_p
handle.batch_generate_pending_output.argtypes = [ctypes.c_int]
handle.batch_generate_pending_output.restype = ctypes.c_char_p
handle.batch_generate_result.argtypes = [ctypes.c_int]
handle.batch_generate_result.restype = generation_outputs
handle.batch_generate_abort.argtypes = [ctypes.c_int]
handle.batch_generate_abort.restype = ctypes.c_bool
handle.batch_generate_release.argtypes = [ctypes.c_int]
handle.batch_generate_release.restype = None
handle.has_audio_support.restype = ctypes.c_bool
handle.has_vision_support.restype = ctypes.c_bool
handle.get_last_eval_time.restype = ctypes.c_float
@ -1912,6 +1933,9 @@ def load_model(model_filename):
inputs.visionmintokens = vmintk
inputs.visionmaxtokens = vmaxtk
inputs.use_smartcontext = args.smartcontext
if getattr(args, "continuous_batching", 0) > 1 and not args.noshift:
print("\nWarning: Continuous batching is enabled, so context shifting has been disabled automatically.\n")
args.noshift = True
inputs.use_contextshift = (0 if args.noshift else 1)
inputs.use_fastforward = (0 if args.nofastforward else 1)
inputs.flash_attention = (False if args.noflashattention else True)
@ -1984,6 +2008,7 @@ def load_model(model_filename):
savestate_limit = sclimit
inputs.smartcacheslots = sclimit
inputs.pipelineparallel = (not args.nopipelineparallel)
inputs.continuous_batching_slots = int(args.continuous_batching) if hasattr(args, "continuous_batching") else 0
inputs = set_backend_props(inputs)
ret = handle.load_model(inputs)
return ret
@ -2230,10 +2255,26 @@ def generate(genparams, stream_flag=False):
pendingabortkey = ""
return {"text":"","status":-1,"stopreason":-1, "prompt_tokens":0, "completion_tokens": 0, "total_tokens": 0}
else:
ret = handle.generate(inputs)
batch_request_id = -1
if getattr(args, "continuous_batching", 0) > 1:
try:
batch_request_id = handle.batch_generate_submit(inputs)
except Exception:
batch_request_id = -1
if batch_request_id >= 0:
genparams['_batch_request_id'] = batch_request_id
ret = handle.batch_generate_result(batch_request_id)
else:
genparams['_batch_fallback'] = True
ret = handle.generate(inputs)
outstr = ""
if ret.status==1:
outstr = ret.text.decode("UTF-8","ignore")
if batch_request_id >= 0 and not stream_flag:
handle.batch_generate_release(batch_request_id)
genparams.pop('_batch_request_id', None)
genparams.pop('_batch_expected', None)
genparams.pop('_batch_fallback', None)
if trimstop:
for trim_str in stop_sequence:
sindex = outstr.find(trim_str)
@ -2241,6 +2282,32 @@ def generate(genparams, stream_flag=False):
outstr = outstr[:sindex]
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
def continuous_batching_python_eligible(genparams, api_format):
if getattr(args, "continuous_batching", 0) <= 1 or api_format <= 0:
return False
model_path = str(getattr(args, "model_param", "") or "").lower()
if model_path and not model_path.endswith(".gguf"):
return False
if not getattr(args, "noshift", False) or getattr(args, "smartcontext", False) or getattr(args, "draftmodel", "") or getattr(args, "mmproj", "") or getattr(args, "enableguidance", False):
return False
if genparams.get("memory") or genparams.get("negative_prompt") or genparams.get("images") or genparams.get("audio"):
return False
if genparams.get("ban_eos_token", False):
return False
if genparams.get("grammar") or genparams.get("grammar_retain_state") or genparams.get("logit_bias") or genparams.get("banned_tokens") or genparams.get("banned_strings"):
return False
if tryparsefloat(genparams.get("dry_multiplier", 0), 0) or tryparseint(genparams.get("mirostat", 0), 0) or tryparsefloat(genparams.get("xtc_probability", 0), 0) or tryparsefloat(genparams.get("nsigma", 0), 0):
return False
if tryparsefloat(genparams.get("smoothing_factor", 0), 0) or tryparsefloat(genparams.get("adaptive_target", -1), -1) > 0 or genparams.get("using_openai_tools", False):
return False
if tryparsefloat(genparams.get("top_a", 0), 0) or tryparsefloat(genparams.get("tfs", 1), 1) != 1 or tryparsefloat(genparams.get("dynatemp_range", 0), 0):
return False
if genparams.get("sampler_order") and genparams.get("sampler_order") != [6, 0, 1, 3, 4, 2, 5]:
return False
if genparams.get("reasoning_effort"):
return False
return True
def sd_get_info():
info = handle.sd_get_info()
if info.status == 0:
@ -4922,19 +4989,29 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
incomplete_token_buffer = bytearray()
async_sleep_short = 0.02
await asyncio.sleep(0.35) #anti race condition, prevent check from overtaking generate
batch_request_id = genparams.get('_batch_request_id', -1)
batch_final_result = None
try:
tokenReserve = "" #keeps fully formed tokens that we cannot send out yet
while True:
streamDone = handle.has_finished() #exit next loop on done
if batch_request_id < 0:
batch_request_id = genparams.get('_batch_request_id', -1)
if genparams.get('_batch_expected', False) and batch_request_id < 0 and not genparams.get('_batch_fallback', False):
await asyncio.sleep(async_sleep_short)
continue
using_batch_stream = batch_request_id >= 0
streamDone = handle.batch_generate_has_finished(batch_request_id) if using_batch_stream else handle.has_finished() #exit next loop on done
if streamDone:
sr = handle.get_last_stop_reason()
if using_batch_stream and batch_final_result is None:
batch_final_result = handle.batch_generate_result(batch_request_id)
sr = batch_final_result.stopreason if using_batch_stream else handle.get_last_stop_reason()
currfinishreason = "error" if sr==-2 else ("length" if (sr!=1) else "stop")
prompttokens = handle.get_last_input_count()
prompttokens = batch_final_result.prompt_tokens if using_batch_stream else handle.get_last_input_count()
tokenStr = ""
streamcount = handle.get_stream_count()
streamcount = handle.batch_generate_stream_count(batch_request_id) if using_batch_stream else handle.get_stream_count()
while current_token < streamcount:
token = handle.new_token(current_token)
token = handle.batch_generate_new_token(batch_request_id, current_token) if using_batch_stream else handle.new_token(current_token)
if token is None: # Token isnt ready yet, received nullpointer
break
@ -5147,7 +5224,8 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
if streamDone:
# content_part.done, reply full text
await asyncio.sleep(async_sleep_short)
finaltxt = handle.get_pending_output().decode("UTF-8", "ignore")
finalraw = handle.batch_generate_pending_output(batch_request_id) if using_batch_stream else handle.get_pending_output()
finaltxt = finalraw.decode("UTF-8", "ignore")
await asyncio.sleep(async_sleep_short)
done_event = json.dumps({"type": "response.output_text.done", "item_id": item_id, "output_index": 0, "sequence_number":rseq_num, "content_index": 0, "text": finaltxt})
rseq_num += 1
@ -5156,7 +5234,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
item_done = json.dumps({"type": "response.output_item.done", "output_index": 0, "sequence_number":rseq_num, "item": { "type": "message", "id": item_id, "status": "completed", "role": "assistant", "content": [{"type": "output_text", "text": finaltxt, "annotations": [], "logprobs": []}]}})
rseq_num += 1
await self.send_oai_responses_sse_event("response.output_item.done",item_done)
usage_pp = handle.get_last_input_count()
usage_pp = batch_final_result.prompt_tokens if using_batch_stream else handle.get_last_input_count()
usage_gen = current_token
res = self.prepare_basic_responses_body(resp_id,genparams)
res["completed_at"] = int(time.time())
@ -5204,8 +5282,17 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
except Exception as ex:
print("Token streaming was interrupted or aborted!")
print(ex)
handle.abort_generate()
if batch_request_id >= 0:
handle.batch_generate_abort(batch_request_id)
else:
handle.abort_generate()
await asyncio.sleep(0.2) #short delay
finally:
if batch_request_id >= 0:
handle.batch_generate_release(batch_request_id)
genparams.pop('_batch_request_id', None)
genparams.pop('_batch_expected', None)
genparams.pop('_batch_fallback', None)
# flush buffers, sleep a bit to make sure all data sent, and then force close the connection
self.wfile.flush()
@ -5271,7 +5358,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
pass
def get_multiplayer_idle_state(self,userid):
if modelbusy.locked():
if modelbusy.locked() or batched_request_runner_count>0:
return False
for key, value in multiplayer_lastactive.items():
if key!=userid and time.time()-value<6: #6s to idle
@ -5308,7 +5395,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
return True
def noscript_webui(self):
global modelbusy, sslvalid
global modelbusy, sslvalid, batched_request_runner_count
parsed_url = urllib.parse.urlparse(self.path)
parsed_dict = urllib.parse.parse_qs(parsed_url.query)
reply = ""
@ -5348,7 +5435,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
else:
gencommand = False
if modelbusy.locked():
if modelbusy.locked() or batched_request_runner_count>0:
status = "Model is currently busy, try again later."
elif gencommand:
if prompt=="" or max_length<=0:
@ -5587,7 +5674,7 @@ Change Mode<br>
"total_tts_gens": totalttsgens,
"total_transcribe_gens": totaltranscribegens,
"queue": requestsinqueue,
"idle": (0 if modelbusy.locked() else 1),
"idle": (0 if (modelbusy.locked() or batched_request_runner_count>0) else 1),
"hordeexitcounter": exitcounter,
"uptime": uptime,
"idletime": idletime,
@ -5850,7 +5937,7 @@ Change Mode<br>
def do_POST(self):
global thinkformats
global modelbusy, requestsinqueue, currentusergenkey, totalgens, pendingabortkey, lastuploadedcomfyimg, lastgeneratedcomfyimg, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, net_save_slots, has_vision_support, savestate_limit, mcp_lock
global modelbusy, batched_request_runner_count, requestsinqueue, currentusergenkey, totalgens, pendingabortkey, lastuploadedcomfyimg, lastgeneratedcomfyimg, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, net_save_slots, has_vision_support, savestate_limit, mcp_lock
global autoswapmode, textName, sttName, ttsName, embedName, musicName, imageName, mmprojName
contlenstr = self.headers['content-length']
content_length = 0
@ -6323,6 +6410,7 @@ Change Mode<br>
"type": "service_unavailable",
}}).encode())
return
is_batchable_req = False
if reqblocking:
requestsinqueue = (requestsinqueue - 1) if requestsinqueue > 0 else 0
@ -6531,10 +6619,22 @@ Change Mode<br>
if args.foreground:
bring_terminal_to_foreground()
#if it's a non-batchable request and we already have batching ongoing, stall this request
if batched_request_runner_count > 0 and not continuous_batching_python_eligible(genparams, api_format):
with batched_cond:
while batched_request_runner_count > 0:
batched_cond.wait()
if api_format > 0: #text gen
# Check if streaming chat completions, if so, set stream mode to true
if (api_format == 4 or api_format == 3 or api_format == 8 or api_format == 9) and "stream" in genparams and genparams["stream"]:
sse_stream_flag = True
if continuous_batching_python_eligible(genparams, api_format):
genparams['_batch_expected'] = True
modelbusy.release()
is_batchable_req = True
with batched_cond:
batched_request_runner_count += 1
gendat = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))
@ -6887,7 +6987,12 @@ Change Mode<br>
finally:
time.sleep(0.05)
modelbusy.release()
if is_batchable_req:
with batched_cond:
batched_request_runner_count -= 1
batched_cond.notify_all()
else:
modelbusy.release()
self.send_response(404)
self.end_headers(content_type='text/html')
@ -11380,6 +11485,7 @@ if __name__ == '__main__':
advparser.add_argument("--analyze", metavar=('[filename]'), help="Reads the metadata, weight types and tensor names in any GGUF file.", default="")
advparser.add_argument("--maingpu","--main-gpu","-mg", help="Only used in a multi-gpu setup. Sets the index of the main GPU that will be used.",metavar=('[Device ID]'), type=int, default=-1)
advparser.add_argument("--batchsize","--blasbatchsize","--batch-size","-b", help="Sets the batch size used in batched processing (default 512). Setting it to -1 disables batched mode, but keeps other benefits like GPU offload.", type=int,choices=[-1,16,32,64,128,256,512,1024,2048,4096], default=512)
advparser.add_argument("--continuous-batching","--contbatch", help=argparse.SUPPRESS, metavar=('[slots]'), type=check_range(int,0,64), default=0)
advparser.add_argument("--blasthreads","--batchthreads","--threadsbatch","--threads-batch", help="Use a different number of threads during batching if specified. Otherwise, has the same value as --threads",metavar=('[threads]'), type=int, default=0)
advparser.add_argument("--splitmode","-sm","--split-mode", help="How to split the model across multiple GPUs", metavar=('[split mode]'), type=str, choices=splitmode_choices, default=splitmode_choices[0])
advparser.add_argument("--nommq", help="Disables MMQ, only used for cuda backend. This flag may be removed in future.", action='store_true')


@ -84,6 +84,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
generation_outputs gpttype_generate(const generation_inputs inputs);
bool gpttype_generate_abort();
std::string gpttype_get_chat_template();
bool gpttype_batch_generate_enabled();
int gpttype_batch_generate_submit(const generation_inputs inputs);
bool gpttype_batch_generate_has_finished(int request_id);
int gpttype_batch_generate_stream_count(int request_id);
const char * gpttype_batch_generate_new_token(int request_id, int idx);
const char * gpttype_batch_generate_pending_output(int request_id);
generation_outputs gpttype_batch_generate_result(int request_id);
bool gpttype_batch_generate_abort(int request_id);
void gpttype_batch_generate_release(int request_id);
const std::string & gpttype_get_pending_output();
std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos);