diff --git a/expose.cpp b/expose.cpp index 346eb14ca..0d42e2a57 100644 --- a/expose.cpp +++ b/expose.cpp @@ -258,6 +258,33 @@ extern "C" bool has_finished() { return generation_finished; } + bool batch_generate_enabled() { + return gpttype_batch_generate_enabled(); + } + int batch_generate_submit(const generation_inputs inputs) { + return gpttype_batch_generate_submit(inputs); + } + bool batch_generate_has_finished(int request_id) { + return gpttype_batch_generate_has_finished(request_id); + } + int batch_generate_stream_count(int request_id) { + return gpttype_batch_generate_stream_count(request_id); + } + const char * batch_generate_new_token(int request_id, int idx) { + return gpttype_batch_generate_new_token(request_id, idx); + } + const char * batch_generate_pending_output(int request_id) { + return gpttype_batch_generate_pending_output(request_id); + } + generation_outputs batch_generate_result(int request_id) { + return gpttype_batch_generate_result(request_id); + } + bool batch_generate_abort(int request_id) { + return gpttype_batch_generate_abort(request_id); + } + void batch_generate_release(int request_id) { + gpttype_batch_generate_release(request_id); + } bool has_audio_support() { return audio_multimodal_supported; diff --git a/expose.h b/expose.h index f44c327ee..e4ee7ff6c 100644 --- a/expose.h +++ b/expose.h @@ -84,6 +84,7 @@ struct load_model_inputs const char * devices_override = nullptr; const bool quiet = false; const int debugmode = 0; + const int continuous_batching_slots = 0; }; struct generation_inputs { @@ -385,3 +386,13 @@ extern int total_transcribe_gens; extern int last_draft_success; extern int last_draft_failed; extern stop_reason last_stop_reason; + +bool gpttype_batch_generate_enabled(); +int gpttype_batch_generate_submit(const generation_inputs inputs); +bool gpttype_batch_generate_has_finished(int request_id); +int gpttype_batch_generate_stream_count(int request_id); +const char * gpttype_batch_generate_new_token(int request_id, 
int idx); +const char * gpttype_batch_generate_pending_output(int request_id); +generation_outputs gpttype_batch_generate_result(int request_id); +bool gpttype_batch_generate_abort(int request_id); +void gpttype_batch_generate_release(int request_id); diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index d55e1b7f1..9bd4d3d07 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -22,6 +22,11 @@ #include #include #include +#include +#include +#include +#include +#include #include "utils.h" #include "llmutils.h" @@ -76,6 +81,7 @@ int last_draft_success = 0; int last_draft_failed = 0; stop_reason last_stop_reason = stop_reason::INVALID; std::vector generated_tokens; +static int continuous_batching_slots = 0; llama_grammar * grammar = nullptr; //currently used grammar llama_grammar_parser parsed_grammar; @@ -2197,6 +2203,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in kcpp_pipeline_parallelism = inputs.pipelineparallel; kcpp_data->n_batch = GetBatchSize(inputs.batchsize, in_file_format); kcpp_data->n_ubatch = kcpp_data->n_batch; + continuous_batching_slots = (isGguf && inputs.continuous_batching_slots > 1) ? 
inputs.continuous_batching_slots : 0; + if(continuous_batching_slots > 0) + { + printf("Continuous batching: prepared %d GGUF sequence slots.\n", continuous_batching_slots); + } kcpp_data->vision_min_tokens = inputs.visionmintokens; kcpp_data->vision_max_tokens = inputs.visionmaxtokens; if(isGguf && kcpp_pipeline_parallelism) @@ -2461,6 +2472,10 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in llama_ctx_params.n_batch = kcpp_data->n_batch; llama_ctx_params.n_ubatch = kcpp_data->n_ubatch; + if(continuous_batching_slots > 0) + { + llama_ctx_params.n_seq_max = continuous_batching_slots + 1; + } llama_ctx_params.n_threads = kcpp_data->n_threads; llama_ctx_params.n_threads_batch = kcpp_data->n_blasthreads; @@ -3269,6 +3284,745 @@ bool gpttype_generate_abort() return true; } +enum class BatchState +{ + WAITING, + PREFILL, + GENERATING, + FINISHED, + FAILED, + ABORTED, +}; + +struct BatchGenerateRequest +{ + int id = 0; + int slot = -1; + BatchState state = BatchState::WAITING; + std::string prompt; + std::vector stop_sequences; + int max_context_length = 0; + int max_length = 0; + int seed = 0; + float temperature = 0.0f; + int top_k = 0; + float top_p = 1.0f; + float min_p = 0.0f; + float typical_p = 1.0f; + float rep_pen = 1.0f; + float rep_pen_slope = 1.0f; + int rep_pen_range = 0; + float presence_penalty = 0.0f; + bool allow_eos_token = true; + bool bypass_eos_token = false; + bool render_special = false; + std::vector prompt_tokens; + int prompt_pos = 0; + int n_past = 0; + bool has_pending = false; + llama_token pending_token = 0; + int i_batch = -1; + llama_sampler * sampler = nullptr; + std::vector generated_pieces; + std::string output; + int prompt_token_count = 0; + int completion_token_count = 0; + std::chrono::steady_clock::time_point start_time; + stop_reason finish_reason = stop_reason::INVALID; + bool abort_requested = false; + generation_outputs result; + + ~BatchGenerateRequest() + { + if(sampler) + { + 
llama_sampler_free(sampler); + sampler = nullptr; + } + } +}; + +static std::mutex batch_mutex; +static std::condition_variable batch_cv; +static std::deque batch_waiting; +static std::vector> batch_requests; +static std::thread batch_worker_thread; +static bool batch_worker_stop = false; +static bool batch_worker_started = false; +static bool batch_legacy_active = false; +static bool batch_touched_since_legacy = false; +static int batch_legacy_waiting = 0; +static int batch_next_request_id = 1; +static std::string batch_empty_string = ""; + +static BatchGenerateRequest * batch_find_request_locked(int request_id) +{ + for(auto & req : batch_requests) + { + if(req && req->id == request_id) + { + return req.get(); + } + } + return nullptr; +} + +static bool batch_is_live_state(BatchState state) +{ + return state == BatchState::WAITING || state == BatchState::PREFILL || state == BatchState::GENERATING; +} + +static bool batch_has_live_locked() +{ + for(const auto & req : batch_requests) + { + if(req && batch_is_live_state(req->state)) + { + return true; + } + } + return false; +} + +static void batch_invalidate_legacy_context_locked() +{ + if(!batch_touched_since_legacy) + { + return; + } + batch_touched_since_legacy = false; + n_past = 0; + current_context_tokens.clear(); + last_n_tokens.clear(); + smartcontext.clear(); + loaded_latest_logits.clear(); + if(llama_ctx_v4) + { + llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), 0, -1, -1); + } + if(draft_ctx) + { + llama_memory_seq_rm(llama_get_memory(draft_ctx), 0, -1, -1); + } + if(debugmode==1 && !is_quiet) + { + printf("\n[Continuous batching touched shared context; forcing next legacy generation to reprocess prompt]\n"); + } +} + +class BatchLegacyGuard +{ +public: + BatchLegacyGuard() + { + std::unique_lock lock(batch_mutex); + batch_legacy_waiting++; + batch_cv.notify_all(); + batch_cv.wait(lock, [](){ return !batch_has_live_locked(); }); + batch_legacy_waiting--; + batch_invalidate_legacy_context_locked(); + 
batch_legacy_active = true; + } + + ~BatchLegacyGuard() + { + std::lock_guard lock(batch_mutex); + batch_legacy_active = false; + batch_cv.notify_all(); + } +}; + +static bool batch_inputs_eligible(const generation_inputs & inputs) +{ + if(continuous_batching_slots <= 1 || file_format != FileFormat::GGUF_GENERIC || !llama_ctx_v4 || !kcpp_data) + { + return false; + } + if(draft_ctx || guidance_ctx || clp_ctx_v || clp_ctx_a) + { + return false; + } + if(kcpp_data->use_smartcontext || kcpp_data->use_contextshift || kcpp_data->smartcache) + { + return false; + } + if(inputs.memory && std::string(inputs.memory).size() > 0) + { + return false; + } + if(inputs.negative_prompt && std::string(inputs.negative_prompt).size() > 0) + { + return false; + } + if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f || !inputs.allow_eos_token) + { + return false; + } + if(inputs.grammar && std::string(inputs.grammar).size() > 0) + { + return false; + } + if(inputs.logit_biases_len > 0 || inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f) + { + return false; + } + if(inputs.mirostat != 0 || inputs.xtc_probability > 0.0f || inputs.nsigma > 0.0f || inputs.smoothing_factor > 0.0f || inputs.adaptive_target > 0.0f) + { + return false; + } + if(inputs.top_a > 0.0f || inputs.tfs != 1.0f || inputs.dynatemp_range > 0.0f) + { + return false; + } + static const int default_sampler_order[] = {6, 0, 1, 3, 4, 2, 5}; + if(inputs.sampler_len > 0) + { + if(inputs.sampler_len != 7) + { + return false; + } + for(int i = 0; i < 7; ++i) + { + if((int) inputs.sampler_order[i] != default_sampler_order[i]) + { + return false; + } + } + } + if(inputs.reasoning_budget >= 0 || inputs.tool_call_fix) + { + return false; + } + return true; +} + +struct BatchRepPenSampler +{ + int32_t penalty_last_n = 0; + float penalty_repeat = 1.0f; + float penalty_slope = 1.0f; + float penalty_present = 0.0f; + std::vector prev; +}; + +static const char * batch_rep_pen_name(const 
llama_sampler * /*smpl*/) +{ + return "kcpp-batch-rep-pen"; +} + +static void batch_rep_pen_accept(llama_sampler * smpl, llama_token token) +{ + auto * ctx = (BatchRepPenSampler *) smpl->ctx; + if(ctx->penalty_last_n <= 0) + { + return; + } + if(ctx->prev.size() >= (size_t) ctx->penalty_last_n) + { + ctx->prev.erase(ctx->prev.begin()); + } + ctx->prev.push_back(token); +} + +static void batch_rep_pen_apply(llama_sampler * smpl, llama_token_data_array * cur_p) +{ + auto * ctx = (BatchRepPenSampler *) smpl->ctx; + int last_n_repeat = std::min((int) ctx->prev.size(), ctx->penalty_last_n); + if(last_n_repeat <= 0 || (ctx->penalty_repeat == 1.0f && ctx->penalty_present == 0.0f)) + { + return; + } + + const llama_token * last_tokens = ctx->prev.data() + ctx->prev.size() - last_n_repeat; + std::unordered_set tokens_near(last_tokens + last_n_repeat / 2, last_tokens + last_n_repeat); + std::unordered_set tokens_far(last_tokens, last_tokens + last_n_repeat / 2); + + float penalty_reduced = ctx->penalty_repeat; + if(penalty_reduced > 1.0f) + { + penalty_reduced = 1.0f + ((ctx->penalty_repeat - 1.0f) * ctx->penalty_slope); + } + + for(size_t i = 0; i < cur_p->size; ++i) + { + const bool token_in_near = tokens_near.find(cur_p->data[i].id) != tokens_near.end(); + const bool token_in_far = tokens_far.find(cur_p->data[i].id) != tokens_far.end(); + if(!token_in_near && !token_in_far) + { + continue; + } + + float penalty = token_in_near ? 
ctx->penalty_repeat : penalty_reduced; + if(cur_p->data[i].logit <= 0) + { + cur_p->data[i].logit *= penalty; + } + else + { + cur_p->data[i].logit /= penalty; + } + cur_p->data[i].logit -= ctx->penalty_present; + } + + cur_p->sorted = false; +} + +static void batch_rep_pen_reset(llama_sampler * smpl) +{ + auto * ctx = (BatchRepPenSampler *) smpl->ctx; + ctx->prev.clear(); +} + +static llama_sampler * batch_rep_pen_clone(const llama_sampler * smpl) +{ + const auto * ctx = (const BatchRepPenSampler *) smpl->ctx; + auto * result = llama_sampler_init(smpl->iface, new BatchRepPenSampler { + ctx->penalty_last_n, + ctx->penalty_repeat, + ctx->penalty_slope, + ctx->penalty_present, + ctx->prev, + }); + return result; +} + +static void batch_rep_pen_free(llama_sampler * smpl) +{ + delete (BatchRepPenSampler *) smpl->ctx; +} + +static llama_sampler_i batch_rep_pen_i = { + /* .name = */ batch_rep_pen_name, + /* .accept = */ batch_rep_pen_accept, + /* .apply = */ batch_rep_pen_apply, + /* .reset = */ batch_rep_pen_reset, + /* .clone = */ batch_rep_pen_clone, + /* .free = */ batch_rep_pen_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, +}; + +static llama_sampler * batch_rep_pen_init(int32_t penalty_last_n, float penalty_repeat, float penalty_slope, float penalty_present) +{ + penalty_last_n = std::max(penalty_last_n, 0); + if(penalty_slope <= 0.0f || penalty_slope > 1.0f) + { + penalty_slope = 1.0f; + } + return llama_sampler_init(&batch_rep_pen_i, new BatchRepPenSampler { + penalty_last_n, + penalty_repeat <= 0.0f ? 
1.0f : penalty_repeat,
+        penalty_slope,
+        penalty_present,
+        {},
+    });
+}
+
+// Builds the llama.cpp sampler chain for one batched request.
+//
+// Stage order: repetition penalty, top_k, typical, top_p, min_p, then either
+// temperature + seeded distribution sampling or greedy selection when
+// temperature <= 0. A stage is only added when its parameter is active.
+//
+// batch_inputs_eligible() admits only the default sampler_order
+// {6, 0, 1, 3, 4, 2, 5}, so this chain has to reproduce that order.
+// NOTE(review): assumes the KoboldAI numbering 0=top_k, 2=top_p, 4=typical,
+// 5=temperature, 6=rep_pen (top_a=1 and tfs=3 are rejected when active), under
+// which typical runs BEFORE top_p. min_p is kept directly after top_p as in the
+// previous version of this function - confirm against the legacy sampler.
+//
+// The caller owns the returned chain and frees it with llama_sampler_free().
+static llama_sampler * batch_build_sampler(const BatchGenerateRequest & req)
+{
+    llama_sampler_chain_params params = llama_sampler_chain_default_params();
+    llama_sampler * chain = llama_sampler_chain_init(params);
+    llama_sampler_chain_add(chain, batch_rep_pen_init(
+        req.rep_pen_range,
+        req.rep_pen,
+        req.rep_pen_slope,
+        req.presence_penalty));
+    if(req.top_k > 0)
+    {
+        llama_sampler_chain_add(chain, llama_sampler_init_top_k(req.top_k));
+    }
+    // typical (id 4) precedes top_p (id 2) in the only order the batch path accepts
+    if(req.typical_p > 0.0f && req.typical_p < 1.0f)
+    {
+        llama_sampler_chain_add(chain, llama_sampler_init_typical(req.typical_p, 1));
+    }
+    if(req.top_p > 0.0f && req.top_p < 1.0f)
+    {
+        llama_sampler_chain_add(chain, llama_sampler_init_top_p(req.top_p, 1));
+    }
+    if(req.min_p > 0.0f)
+    {
+        llama_sampler_chain_add(chain, llama_sampler_init_min_p(req.min_p, 1));
+    }
+    if(req.temperature > 0.0f)
+    {
+        llama_sampler_chain_add(chain, llama_sampler_init_temp(req.temperature));
+        // a negative seed falls back to LLAMA_DEFAULT_SEED
+        llama_sampler_chain_add(chain, llama_sampler_init_dist(req.seed < 0 ? LLAMA_DEFAULT_SEED : (uint32_t) req.seed));
+    }
+    else
+    {
+        llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+    }
+    return chain;
+}
+
+static void batch_finish_request_locked(BatchGenerateRequest & req, stop_reason reason)
+{
+    auto finish_time = std::chrono::steady_clock::now();
+    // duration<float> yields seconds, matching the "%.2fs" log line below
+    float total_time = req.start_time.time_since_epoch().count() == 0 ? 0.0f : std::chrono::duration<float>(finish_time - req.start_time).count();
+    float generated_tps = total_time > 0.0f ? (float) req.completion_token_count / total_time : 0.0f;
+    req.finish_reason = reason;
+    req.result.status = (reason == stop_reason::ERROR_ENCOUNTERED) ? 0 : 1;
+    req.result.stopreason = reason;
+    req.result.prompt_tokens = req.prompt_token_count;
+    req.result.completion_tokens = req.completion_token_count;
+    req.result.text = req.output.c_str();
+    req.state = reason == stop_reason::ERROR_ENCOUNTERED ? 
BatchState::FAILED : (reason == stop_reason::INVALID ? BatchState::ABORTED : BatchState::FINISHED); + if(req.slot >= 0 && llama_ctx_v4) + { + llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), req.slot, -1, -1); + } + req.slot = -1; + printf("\n[%s] BatchRequest:%d, Prompt:%d, Generated:%d/%d in %.2fs (%.2fT/s), Stop:%d", + get_timestamp_str().c_str(), req.id, req.prompt_token_count, req.completion_token_count, req.max_length, total_time, generated_tps, (int) reason); + fflush(stdout); + batch_cv.notify_all(); +} + +static bool batch_output_hit_stop(const BatchGenerateRequest & req) +{ + for(const auto & stopper : req.stop_sequences) + { + if(!stopper.empty() && req.output.find(stopper) != std::string::npos) + { + return true; + } + } + return false; +} + +static bool batch_claim_waiting_locked() +{ + bool claimed = false; + for(int slot = 1; slot <= continuous_batching_slots && !batch_waiting.empty(); ++slot) + { + bool occupied = false; + for(const auto & req : batch_requests) + { + if(req && req->slot == slot && batch_is_live_state(req->state)) + { + occupied = true; + break; + } + } + if(occupied) + { + continue; + } + int request_id = batch_waiting.front(); + batch_waiting.pop_front(); + BatchGenerateRequest * req = batch_find_request_locked(request_id); + if(!req || req->state != BatchState::WAITING) + { + continue; + } + req->slot = slot; + req->state = BatchState::PREFILL; + batch_touched_since_legacy = true; + TokenizeString(req->prompt, req->prompt_tokens, file_format, add_bos_token); + if(req->prompt_tokens.empty()) + { + TokenizeString("", req->prompt_tokens, file_format, add_bos_token); + } + int n_ctx = req->max_context_length > 0 ? 
std::min(req->max_context_length, kcpp_data->n_ctx) : kcpp_data->n_ctx;
+        if(req->max_length > 0 && (int) req->prompt_tokens.size() + req->max_length > n_ctx)
+        {
+            int keep = std::max(1, n_ctx - req->max_length);
+            if((int) req->prompt_tokens.size() > keep)
+            {
+                req->prompt_tokens.erase(req->prompt_tokens.begin(), req->prompt_tokens.end() - keep);
+            }
+        }
+        req->prompt_token_count = req->prompt_tokens.size();
+        req->sampler = batch_build_sampler(*req);
+        for(llama_token token : req->prompt_tokens)
+        {
+            llama_sampler_accept(req->sampler, token);
+        }
+        req->prompt_pos = 0;
+        req->n_past = 0;
+        req->has_pending = false;
+        req->i_batch = -1;
+        req->start_time = std::chrono::steady_clock::now();
+        llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), slot, -1, -1);
+        claimed = true;
+    }
+    return claimed;
+}
+
+// Body of the background worker thread that drives all batched requests.
+//
+// Each iteration:
+//   1. under batch_mutex: waits (5 ms poll) until there is live work and no
+//      legacy generation is active, assigns free slots to waiting requests and
+//      packs one llama_batch from every slotted request (prompt chunks for
+//      PREFILL, the single pending token for GENERATING);
+//   2. without the mutex: runs llama_decode() on that batch;
+//   3. under batch_mutex again: samples one token for every request whose
+//      logits were requested and applies the EOS / stop-sequence / max_length
+//      termination rules.
+//
+// Exits only when batch_worker_stop is set.
+static void batch_worker_loop()
+{
+    const int batch_cap = std::max(1, kcpp_data ? kcpp_data->n_batch : 512);
+    llama_batch batch = llama_batch_init(batch_cap, 0, 1);
+    while(true)
+    {
+        std::vector<int> batched_ids; // every request that put tokens into this batch
+        std::vector<int> decode_ids;  // subset that requested logits and must be sampled
+        {
+            std::unique_lock<std::mutex> lock(batch_mutex);
+            batch_cv.wait_for(lock, std::chrono::milliseconds(5), [](){
+                return batch_worker_stop || (!batch_legacy_active && batch_has_live_locked());
+            });
+            if(batch_worker_stop)
+            {
+                break;
+            }
+            if(batch_legacy_active)
+            {
+                continue;
+            }
+            batch_claim_waiting_locked();
+            common_batch_clear(batch);
+            for(auto & req_ptr : batch_requests)
+            {
+                if(!req_ptr || !batch_is_live_state(req_ptr->state) || req_ptr->slot < 0 || batch.n_tokens >= batch_cap)
+                {
+                    continue;
+                }
+                BatchGenerateRequest & req = *req_ptr;
+                req.i_batch = -1;
+                if(req.abort_requested)
+                {
+                    batch_finish_request_locked(req, stop_reason::INVALID);
+                    continue;
+                }
+                const int tokens_before = batch.n_tokens;
+                if(req.state == BatchState::PREFILL)
+                {
+                    if(req.prompt_tokens.empty())
+                    {
+                        // Nothing to decode means no logits to sample from. Without this
+                        // guard the request would flip to GENERATING with no pending
+                        // token and stay live forever, blocking result() and the legacy guard.
+                        batch_finish_request_locked(req, stop_reason::ERROR_ENCOUNTERED);
+                        continue;
+                    }
+                    while(req.prompt_pos < (int) req.prompt_tokens.size() && batch.n_tokens < batch_cap)
+                    {
+                        // only the final prompt token needs logits
+                        bool is_last = req.prompt_pos == (int) req.prompt_tokens.size() - 1;
+                        if(is_last)
+                        {
+                            req.i_batch = batch.n_tokens;
+                        }
+                        common_batch_add(batch, req.prompt_tokens[req.prompt_pos], req.n_past, { req.slot }, is_last);
+                        req.prompt_pos++;
+                        req.n_past++;
+                    }
+                    if(req.prompt_pos == (int) req.prompt_tokens.size())
+                    {
+                        req.state = BatchState::GENERATING;
+                    }
+                }
+                else if(req.state == BatchState::GENERATING && req.has_pending)
+                {
+                    req.i_batch = batch.n_tokens;
+                    common_batch_add(batch, req.pending_token, req.n_past, { req.slot }, true);
+                    req.n_past++;
+                    req.has_pending = false;
+                }
+                if(batch.n_tokens > tokens_before)
+                {
+                    batched_ids.push_back(req.id);
+                    if(req.i_batch >= 0)
+                    {
+                        decode_ids.push_back(req.id);
+                    }
+                }
+            }
+            if(batch.n_tokens == 0)
+            {
+                continue;
+            }
+        }
+
+        // decode without holding batch_mutex so API calls are not blocked meanwhile
+        const int decode_status = llama_decode(llama_ctx_v4, batch);
+
+        std::lock_guard<std::mutex> lock(batch_mutex);
+        if(decode_status != 0)
+        {
+            // Fail EVERY request that had tokens in this batch, not only the ones that
+            // asked for logits: a request in mid-prompt prefill has already advanced
+            // prompt_pos/n_past past tokens that were never decoded.
+            for(int request_id : batched_ids)
+            {
+                BatchGenerateRequest * req = batch_find_request_locked(request_id);
+                if(req && batch_is_live_state(req->state))
+                {
+                    batch_finish_request_locked(*req, stop_reason::ERROR_ENCOUNTERED);
+                }
+            }
+            continue;
+        }
+
+        const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
+        llama_token eos = llama_vocab_eos(vocab);
+        for(int request_id : decode_ids)
+        {
+            BatchGenerateRequest * req = batch_find_request_locked(request_id);
+            if(!req || req->state != BatchState::GENERATING || req->i_batch < 0)
+            {
+                continue;
+            }
+            // consume the logits index now so it is never left set on a finished request
+            const int logits_idx = req->i_batch;
+            req->i_batch = -1;
+            llama_token sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, logits_idx);
+            if(!req->allow_eos_token && !req->bypass_eos_token && sampled == eos)
+            {
+                // NOTE(review): a single redraw does not guarantee a non-EOS token;
+                // batch_inputs_eligible() currently rejects !allow_eos_token requests,
+                // so this branch is not expected to run.
+                sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, logits_idx);
+            }
+            req->completion_token_count++;
+            if(sampled == eos && !req->bypass_eos_token)
+            {
+                batch_finish_request_locked(*req, stop_reason::EOS_TOKEN_HIT);
+                continue;
+            }
+            std::string piece = FileFormatTokenizeID(sampled, file_format, req->render_special);
+            req->generated_pieces.push_back(piece);
+            req->output += piece;
+            if(batch_output_hit_stop(*req))
+            {
+                batch_finish_request_locked(*req, stop_reason::CUSTOM_STOPPER);
+                continue;
+            }
+            if(req->max_length > 0 && req->completion_token_count >= req->max_length)
+            {
+                batch_finish_request_locked(*req, stop_reason::OUT_OF_TOKENS);
+                continue;
+            }
+            // feed the sampled token back in on the next iteration
+            req->pending_token = sampled;
+            req->has_pending = true;
+        }
+    }
+    llama_batch_free(batch);
+}
+
+// Starts the worker thread once; later calls are no-ops. Caller holds batch_mutex.
+// The thread is detached, so nothing in this file joins it.
+static void batch_start_worker_locked()
+{
+    if(batch_worker_started)
+    {
+        return;
+    }
+    batch_worker_stop = false;
+    batch_worker_thread = std::thread(batch_worker_loop);
+    batch_worker_thread.detach();
+    batch_worker_started = true;
+}
+
+// True when more than one slot was configured and a GGUF llama context is loaded.
+bool gpttype_batch_generate_enabled()
+{
+    return continuous_batching_slots > 1 && file_format == FileFormat::GGUF_GENERIC && llama_ctx_v4 && kcpp_data;
+}
+
+// Queues a request for the batch worker. Returns -1 when the inputs are not
+// eligible or a legacy generation is running or waiting to run.
+int gpttype_batch_generate_submit(const generation_inputs inputs)
+{
+    if(!batch_inputs_eligible(inputs))
+    {
+        return -1;
+    }
+    std::lock_guard<std::mutex> lock(batch_mutex);
+    if(batch_legacy_active || batch_legacy_waiting > 0)
+    {
+        return -1;
+    }
+    auto req = std::make_unique<BatchGenerateRequest>();
+    req->id = batch_next_request_id++;
+    req->prompt = inputs.prompt ? 
inputs.prompt : ""; + req->max_context_length = inputs.max_context_length; + req->max_length = inputs.max_length; + req->seed = inputs.seed; + req->temperature = inputs.temperature; + req->top_k = inputs.top_k; + req->top_p = inputs.top_p; + req->min_p = inputs.min_p; + req->typical_p = inputs.typical_p; + req->rep_pen = inputs.rep_pen; + req->rep_pen_slope = inputs.rep_pen_slope; + req->rep_pen_range = inputs.rep_pen_range; + req->presence_penalty = inputs.presence_penalty; + req->allow_eos_token = inputs.allow_eos_token; + req->bypass_eos_token = inputs.bypass_eos_token; + req->render_special = inputs.render_special; + for(int i = 0; i < inputs.stop_sequence_len; ++i) + { + if(inputs.stop_sequence[i]) + { + req->stop_sequences.emplace_back(inputs.stop_sequence[i]); + } + } + int request_id = req->id; + batch_requests.emplace_back(std::move(req)); + batch_waiting.push_back(request_id); + batch_start_worker_locked(); + batch_cv.notify_all(); + return request_id; +} + +bool gpttype_batch_generate_has_finished(int request_id) +{ + std::lock_guard lock(batch_mutex); + BatchGenerateRequest * req = batch_find_request_locked(request_id); + return !req || !batch_is_live_state(req->state); +} + +int gpttype_batch_generate_stream_count(int request_id) +{ + std::lock_guard lock(batch_mutex); + BatchGenerateRequest * req = batch_find_request_locked(request_id); + return req ? 
req->generated_pieces.size() : 0;
+}
+
+// Returns the idx-th generated text piece of a request, or nullptr when the
+// request is unknown or idx is out of range.
+//
+// The text is copied into a per-thread buffer while batch_mutex is held. The
+// previous version returned a pointer into generated_pieces itself; once the
+// mutex was dropped the worker could push_back() and reallocate that vector
+// before the caller had copied the bytes. The returned pointer stays valid
+// until the same thread calls this function again.
+const char * gpttype_batch_generate_new_token(int request_id, int idx)
+{
+    static thread_local std::string piece_snapshot;
+    std::lock_guard<std::mutex> lock(batch_mutex);
+    BatchGenerateRequest * req = batch_find_request_locked(request_id);
+    if(!req || idx < 0 || idx >= (int) req->generated_pieces.size())
+    {
+        return nullptr;
+    }
+    piece_snapshot = req->generated_pieces[idx];
+    return piece_snapshot.c_str();
+}
+
+// Returns all text generated so far for a request ("" when the request is unknown).
+// Snapshotted per thread for the same reason as above: `output` keeps growing
+// (and may reallocate) while the request is live.
+const char * gpttype_batch_generate_pending_output(int request_id)
+{
+    static thread_local std::string output_snapshot;
+    std::lock_guard<std::mutex> lock(batch_mutex);
+    BatchGenerateRequest * req = batch_find_request_locked(request_id);
+    if(!req)
+    {
+        return batch_empty_string.c_str();
+    }
+    output_snapshot = req->output;
+    return output_snapshot.c_str();
+}
+
+// Blocks until the request is no longer live, then returns its result.
+// An unknown request yields status 0 / ERROR_ENCOUNTERED with empty text.
+//
+// The returned `text` points at a per-thread copy, so it stays readable even if
+// another thread calls gpttype_batch_generate_release() for the same request
+// before the caller has read it. It is valid until this thread's next call.
+generation_outputs gpttype_batch_generate_result(int request_id)
+{
+    static thread_local std::string result_text;
+    std::unique_lock<std::mutex> lock(batch_mutex);
+    batch_cv.wait(lock, [request_id](){
+        BatchGenerateRequest * req = batch_find_request_locked(request_id);
+        return !req || !batch_is_live_state(req->state);
+    });
+    BatchGenerateRequest * req = batch_find_request_locked(request_id);
+    if(!req)
+    {
+        generation_outputs output;
+        output.status = 0;
+        output.stopreason = stop_reason::ERROR_ENCOUNTERED;
+        output.prompt_tokens = 0;
+        output.completion_tokens = 0;
+        output.text = batch_empty_string.c_str();
+        return output;
+    }
+    req->result.text = req->output.c_str();
+    result_text = req->output;
+    generation_outputs output = req->result;
+    output.text = result_text.c_str();
+    return output;
+}
+
+// Requests cancellation. Returns false only when the request id is unknown.
+bool gpttype_batch_generate_abort(int request_id)
+{
+    std::lock_guard<std::mutex> lock(batch_mutex);
+    BatchGenerateRequest * req = batch_find_request_locked(request_id);
+    if(!req)
+    {
+        return false;
+    }
+    req->abort_requested = true;
+    if(req->state == BatchState::WAITING)
+    {
+        // A waiting request owns no slot, and the worker only inspects slotted
+        // requests, so it would otherwise stay live until a slot frees up.
+        // Finish it here; its stale id left in batch_waiting is skipped when popped.
+        batch_finish_request_locked(*req, stop_reason::INVALID);
+    }
+    batch_cv.notify_all();
+    return true;
+}
+
+// Drops the bookkeeping of a request that is no longer live; live requests are kept.
+void gpttype_batch_generate_release(int request_id)
+{
+    std::lock_guard<std::mutex> lock(batch_mutex);
+    batch_requests.erase(std::remove_if(batch_requests.begin(), batch_requests.end(), [request_id](const std::unique_ptr<BatchGenerateRequest> & req){
+        return req && req->id == request_id && !batch_is_live_state(req->state);
+    }), batch_requests.end());
+    batch_cv.notify_all();
+}
+
 std::string gpttype_get_chat_template()
 {
     if(kcpp_data==nullptr)
@@ -3547,6 +4301,7 
@@ int smartcache_quick_snapshot(int specific_slot = -1) generation_outputs gpttype_generate(const generation_inputs inputs) { + BatchLegacyGuard batch_legacy_guard; generation_outputs output; if(kcpp_data==nullptr) diff --git a/koboldcpp.py b/koboldcpp.py index 4c1cf4088..5972f0f9d 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -110,6 +110,9 @@ maxctx = 8192 maxhordectx = 0 #set to whatever maxctx is if 0 maxhordelen = 1024 modelbusy = threading.Lock() +batched_lock = threading.Lock() +batched_cond = threading.Condition(batched_lock) +batched_request_runner_count = 0 #incremented when a batched request is running, prevents all non-batched requests requestsinqueue = 0 ratelimitlookup = {} defaultport = 5001 @@ -284,7 +287,8 @@ class load_model_inputs(ctypes.Structure): ("lora_multiplier", ctypes.c_float), ("devices_override", ctypes.c_char_p), ("quiet", ctypes.c_bool), - ("debugmode", ctypes.c_int)] + ("debugmode", ctypes.c_int), + ("continuous_batching_slots", ctypes.c_int)] class generation_inputs(ctypes.Structure): _fields_ = [("seed", ctypes.c_int), @@ -893,6 +897,23 @@ def init_library(): handle.new_token.argtypes = [ctypes.c_int] handle.get_stream_count.restype = ctypes.c_int handle.has_finished.restype = ctypes.c_bool + handle.batch_generate_enabled.restype = ctypes.c_bool + handle.batch_generate_submit.argtypes = [generation_inputs] + handle.batch_generate_submit.restype = ctypes.c_int + handle.batch_generate_has_finished.argtypes = [ctypes.c_int] + handle.batch_generate_has_finished.restype = ctypes.c_bool + handle.batch_generate_stream_count.argtypes = [ctypes.c_int] + handle.batch_generate_stream_count.restype = ctypes.c_int + handle.batch_generate_new_token.argtypes = [ctypes.c_int, ctypes.c_int] + handle.batch_generate_new_token.restype = ctypes.c_char_p + handle.batch_generate_pending_output.argtypes = [ctypes.c_int] + handle.batch_generate_pending_output.restype = ctypes.c_char_p + handle.batch_generate_result.argtypes = [ctypes.c_int] + 
handle.batch_generate_result.restype = generation_outputs + handle.batch_generate_abort.argtypes = [ctypes.c_int] + handle.batch_generate_abort.restype = ctypes.c_bool + handle.batch_generate_release.argtypes = [ctypes.c_int] + handle.batch_generate_release.restype = None handle.has_audio_support.restype = ctypes.c_bool handle.has_vision_support.restype = ctypes.c_bool handle.get_last_eval_time.restype = ctypes.c_float @@ -1912,6 +1933,9 @@ def load_model(model_filename): inputs.visionmintokens = vmintk inputs.visionmaxtokens = vmaxtk inputs.use_smartcontext = args.smartcontext + if getattr(args, "continuous_batching", 0) > 1 and not args.noshift: + print("\nWarning: Continuous batching is enabled, so context shifting has been disabled automatically.\n") + args.noshift = True inputs.use_contextshift = (0 if args.noshift else 1) inputs.use_fastforward = (0 if args.nofastforward else 1) inputs.flash_attention = (False if args.noflashattention else True) @@ -1984,6 +2008,7 @@ def load_model(model_filename): savestate_limit = sclimit inputs.smartcacheslots = sclimit inputs.pipelineparallel = (not args.nopipelineparallel) + inputs.continuous_batching_slots = int(args.continuous_batching) if hasattr(args, "continuous_batching") else 0 inputs = set_backend_props(inputs) ret = handle.load_model(inputs) return ret @@ -2230,10 +2255,26 @@ def generate(genparams, stream_flag=False): pendingabortkey = "" return {"text":"","status":-1,"stopreason":-1, "prompt_tokens":0, "completion_tokens": 0, "total_tokens": 0} else: - ret = handle.generate(inputs) + batch_request_id = -1 + if getattr(args, "continuous_batching", 0) > 1: + try: + batch_request_id = handle.batch_generate_submit(inputs) + except Exception: + batch_request_id = -1 + if batch_request_id >= 0: + genparams['_batch_request_id'] = batch_request_id + ret = handle.batch_generate_result(batch_request_id) + else: + genparams['_batch_fallback'] = True + ret = handle.generate(inputs) outstr = "" if ret.status==1: outstr = 
ret.text.decode("UTF-8","ignore") + if batch_request_id >= 0 and not stream_flag: + handle.batch_generate_release(batch_request_id) + genparams.pop('_batch_request_id', None) + genparams.pop('_batch_expected', None) + genparams.pop('_batch_fallback', None) if trimstop: for trim_str in stop_sequence: sindex = outstr.find(trim_str) @@ -2241,6 +2282,32 @@ def generate(genparams, stream_flag=False): outstr = outstr[:sindex] return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens} +def continuous_batching_python_eligible(genparams, api_format): + if getattr(args, "continuous_batching", 0) <= 1 or api_format <= 0: + return False + model_path = str(getattr(args, "model_param", "") or "").lower() + if model_path and not model_path.endswith(".gguf"): + return False + if not getattr(args, "noshift", False) or getattr(args, "smartcontext", False) or getattr(args, "draftmodel", "") or getattr(args, "mmproj", "") or getattr(args, "enableguidance", False): + return False + if genparams.get("memory") or genparams.get("negative_prompt") or genparams.get("images") or genparams.get("audio"): + return False + if genparams.get("ban_eos_token", False): + return False + if genparams.get("grammar") or genparams.get("grammar_retain_state") or genparams.get("logit_bias") or genparams.get("banned_tokens") or genparams.get("banned_strings"): + return False + if tryparsefloat(genparams.get("dry_multiplier", 0), 0) or tryparseint(genparams.get("mirostat", 0), 0) or tryparsefloat(genparams.get("xtc_probability", 0), 0) or tryparsefloat(genparams.get("nsigma", 0), 0): + return False + if tryparsefloat(genparams.get("smoothing_factor", 0), 0) or tryparsefloat(genparams.get("adaptive_target", -1), -1) > 0 or genparams.get("using_openai_tools", False): + return False + if tryparsefloat(genparams.get("top_a", 0), 0) or tryparsefloat(genparams.get("tfs", 1), 1) != 1 or 
tryparsefloat(genparams.get("dynatemp_range", 0), 0): + return False + if genparams.get("sampler_order") and genparams.get("sampler_order") != [6, 0, 1, 3, 4, 2, 5]: + return False + if genparams.get("reasoning_effort"): + return False + return True + def sd_get_info(): info = handle.sd_get_info() if info.status == 0: @@ -4922,19 +4989,29 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler): incomplete_token_buffer = bytearray() async_sleep_short = 0.02 await asyncio.sleep(0.35) #anti race condition, prevent check from overtaking generate + batch_request_id = genparams.get('_batch_request_id', -1) + batch_final_result = None try: tokenReserve = "" #keeps fully formed tokens that we cannot send out yet while True: - streamDone = handle.has_finished() #exit next loop on done + if batch_request_id < 0: + batch_request_id = genparams.get('_batch_request_id', -1) + if genparams.get('_batch_expected', False) and batch_request_id < 0 and not genparams.get('_batch_fallback', False): + await asyncio.sleep(async_sleep_short) + continue + using_batch_stream = batch_request_id >= 0 + streamDone = handle.batch_generate_has_finished(batch_request_id) if using_batch_stream else handle.has_finished() #exit next loop on done if streamDone: - sr = handle.get_last_stop_reason() + if using_batch_stream and batch_final_result is None: + batch_final_result = handle.batch_generate_result(batch_request_id) + sr = batch_final_result.stopreason if using_batch_stream else handle.get_last_stop_reason() currfinishreason = "error" if sr==-2 else ("length" if (sr!=1) else "stop") - prompttokens = handle.get_last_input_count() + prompttokens = batch_final_result.prompt_tokens if using_batch_stream else handle.get_last_input_count() tokenStr = "" - streamcount = handle.get_stream_count() + streamcount = handle.batch_generate_stream_count(batch_request_id) if using_batch_stream else handle.get_stream_count() while current_token < streamcount: - token = 
handle.new_token(current_token) + token = handle.batch_generate_new_token(batch_request_id, current_token) if using_batch_stream else handle.new_token(current_token) if token is None: # Token isnt ready yet, received nullpointer break @@ -5147,7 +5224,8 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler): if streamDone: # content_part.done, reply full text await asyncio.sleep(async_sleep_short) - finaltxt = handle.get_pending_output().decode("UTF-8", "ignore") + finalraw = handle.batch_generate_pending_output(batch_request_id) if using_batch_stream else handle.get_pending_output() + finaltxt = finalraw.decode("UTF-8", "ignore") await asyncio.sleep(async_sleep_short) done_event = json.dumps({"type": "response.output_text.done", "item_id": item_id, "output_index": 0, "sequence_number":rseq_num, "content_index": 0, "text": finaltxt}) rseq_num += 1 @@ -5156,7 +5234,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler): item_done = json.dumps({"type": "response.output_item.done", "output_index": 0, "sequence_number":rseq_num, "item": { "type": "message", "id": item_id, "status": "completed", "role": "assistant", "content": [{"type": "output_text", "text": finaltxt, "annotations": [], "logprobs": []}]}}) rseq_num += 1 await self.send_oai_responses_sse_event("response.output_item.done",item_done) - usage_pp = handle.get_last_input_count() + usage_pp = batch_final_result.prompt_tokens if using_batch_stream else handle.get_last_input_count() usage_gen = current_token res = self.prepare_basic_responses_body(resp_id,genparams) res["completed_at"] = int(time.time()) @@ -5204,8 +5282,17 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler): except Exception as ex: print("Token streaming was interrupted or aborted!") print(ex) - handle.abort_generate() + if batch_request_id >= 0: + handle.batch_generate_abort(batch_request_id) + else: + handle.abort_generate() await asyncio.sleep(0.2) #short delay + finally: + if 
batch_request_id >= 0: + handle.batch_generate_release(batch_request_id) + genparams.pop('_batch_request_id', None) + genparams.pop('_batch_expected', None) + genparams.pop('_batch_fallback', None) # flush buffers, sleep a bit to make sure all data sent, and then force close the connection self.wfile.flush() @@ -5271,7 +5358,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler): pass def get_multiplayer_idle_state(self,userid): - if modelbusy.locked(): + if modelbusy.locked() or batched_request_runner_count>0: return False for key, value in multiplayer_lastactive.items(): if key!=userid and time.time()-value<6: #6s to idle @@ -5308,7 +5395,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler): return True def noscript_webui(self): - global modelbusy, sslvalid + global modelbusy, sslvalid, batched_request_runner_count parsed_url = urllib.parse.urlparse(self.path) parsed_dict = urllib.parse.parse_qs(parsed_url.query) reply = "" @@ -5348,7 +5435,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler): else: gencommand = False - if modelbusy.locked(): + if modelbusy.locked() or batched_request_runner_count>0: status = "Model is currently busy, try again later." elif gencommand: if prompt=="" or max_length<=0: @@ -5587,7 +5674,7 @@ Change Mode
"total_tts_gens": totalttsgens, "total_transcribe_gens": totaltranscribegens, "queue": requestsinqueue, - "idle": (0 if modelbusy.locked() else 1), + "idle": (0 if (modelbusy.locked() or batched_request_runner_count>0) else 1), "hordeexitcounter": exitcounter, "uptime": uptime, "idletime": idletime, @@ -5850,7 +5937,7 @@ Change Mode
def do_POST(self): global thinkformats - global modelbusy, requestsinqueue, currentusergenkey, totalgens, pendingabortkey, lastuploadedcomfyimg, lastgeneratedcomfyimg, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, net_save_slots, has_vision_support, savestate_limit, mcp_lock + global modelbusy, batched_request_runner_count, requestsinqueue, currentusergenkey, totalgens, pendingabortkey, lastuploadedcomfyimg, lastgeneratedcomfyimg, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, net_save_slots, has_vision_support, savestate_limit, mcp_lock global autoswapmode, textName, sttName, ttsName, embedName, musicName, imageName, mmprojName contlenstr = self.headers['content-length'] content_length = 0 @@ -6323,6 +6410,7 @@ Change Mode
"type": "service_unavailable", }}).encode()) return + is_batchable_req = False if reqblocking: requestsinqueue = (requestsinqueue - 1) if requestsinqueue > 0 else 0 @@ -6531,10 +6619,22 @@ Change Mode
if args.foreground: bring_terminal_to_foreground() + #if it's a non-batchable request and we already have batching ongoing, stall this request + if batched_request_runner_count > 0 and not continuous_batching_python_eligible(genparams, api_format): + with batched_cond: + while batched_request_runner_count > 0: + batched_cond.wait() + if api_format > 0: #text gen # Check if streaming chat completions, if so, set stream mode to true if (api_format == 4 or api_format == 3 or api_format == 8 or api_format == 9) and "stream" in genparams and genparams["stream"]: sse_stream_flag = True + if continuous_batching_python_eligible(genparams, api_format): + genparams['_batch_expected'] = True + modelbusy.release() + is_batchable_req = True + with batched_cond: + batched_request_runner_count += 1 gendat = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag)) @@ -6887,7 +6987,12 @@ Change Mode
finally: time.sleep(0.05) - modelbusy.release() + if is_batchable_req: + with batched_cond: + batched_request_runner_count -= 1 + batched_cond.notify_all() + else: + modelbusy.release() self.send_response(404) self.end_headers(content_type='text/html') @@ -11380,6 +11485,7 @@ if __name__ == '__main__': advparser.add_argument("--analyze", metavar=('[filename]'), help="Reads the metadata, weight types and tensor names in any GGUF file.", default="") advparser.add_argument("--maingpu","--main-gpu","-mg", help="Only used in a multi-gpu setup. Sets the index of the main GPU that will be used.",metavar=('[Device ID]'), type=int, default=-1) advparser.add_argument("--batchsize","--blasbatchsize","--batch-size","-b", help="Sets the batch size used in batched processing (default 512). Setting it to -1 disables batched mode, but keeps other benefits like GPU offload.", type=int,choices=[-1,16,32,64,128,256,512,1024,2048,4096], default=512) + advparser.add_argument("--continuous-batching","--contbatch", help=argparse.SUPPRESS, metavar=('[slots]'), type=check_range(int,0,64), default=0) advparser.add_argument("--blasthreads","--batchthreads","--threadsbatch","--threads-batch", help="Use a different number of threads during batching if specified. Otherwise, has the same value as --threads",metavar=('[threads]'), type=int, default=0) advparser.add_argument("--splitmode","-sm","--split-mode", help="How to split the model across multiple GPUs", metavar=('[split mode]'), type=str, choices=splitmode_choices, default=splitmode_choices[0]) advparser.add_argument("--nommq", help="Disables MMQ, only used for cuda backend. 
This flag may be removed in future.", action='store_true') diff --git a/model_adapter.h b/model_adapter.h index dc6f8b86f..503573124 100644 --- a/model_adapter.h +++ b/model_adapter.h @@ -84,6 +84,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in generation_outputs gpttype_generate(const generation_inputs inputs); bool gpttype_generate_abort(); std::string gpttype_get_chat_template(); +bool gpttype_batch_generate_enabled(); +int gpttype_batch_generate_submit(const generation_inputs inputs); +bool gpttype_batch_generate_has_finished(int request_id); +int gpttype_batch_generate_stream_count(int request_id); +const char * gpttype_batch_generate_new_token(int request_id, int idx); +const char * gpttype_batch_generate_pending_output(int request_id); +generation_outputs gpttype_batch_generate_result(int request_id); +bool gpttype_batch_generate_abort(int request_id); +void gpttype_batch_generate_release(int request_id); const std::string & gpttype_get_pending_output(); std::vector gpttype_get_token_arr(const std::string & input, bool addbos);