mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-16 19:59:16 +00:00
feat: add a primitive form of continuous batching (#2167)
feat: add a primitive form of continuous batching (#2167)

* feat: add a primitive form of continuous batching
* fix: deadlock in batching fallback
* fix: windows build
* chore: suppress the contbatch arg from --help
* feat: batch-aware rep_pen_slope
* fix: automatically disable shifting when batching is enabled
* fix: mixed-path state corruption
* fix: attempt to fully separate the two pipelines
* added a semaphore to prevent non-batchable requests from starting while batched requests are running

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
parent
a47037637c
commit
c03302b670
5 changed files with 924 additions and 16 deletions
27
expose.cpp
27
expose.cpp
|
|
@ -258,6 +258,33 @@ extern "C"
|
|||
bool has_finished() {
|
||||
return generation_finished;
|
||||
}
|
||||
bool batch_generate_enabled() {
|
||||
return gpttype_batch_generate_enabled();
|
||||
}
|
||||
int batch_generate_submit(const generation_inputs inputs) {
|
||||
return gpttype_batch_generate_submit(inputs);
|
||||
}
|
||||
bool batch_generate_has_finished(int request_id) {
|
||||
return gpttype_batch_generate_has_finished(request_id);
|
||||
}
|
||||
int batch_generate_stream_count(int request_id) {
|
||||
return gpttype_batch_generate_stream_count(request_id);
|
||||
}
|
||||
const char * batch_generate_new_token(int request_id, int idx) {
|
||||
return gpttype_batch_generate_new_token(request_id, idx);
|
||||
}
|
||||
const char * batch_generate_pending_output(int request_id) {
|
||||
return gpttype_batch_generate_pending_output(request_id);
|
||||
}
|
||||
generation_outputs batch_generate_result(int request_id) {
|
||||
return gpttype_batch_generate_result(request_id);
|
||||
}
|
||||
bool batch_generate_abort(int request_id) {
|
||||
return gpttype_batch_generate_abort(request_id);
|
||||
}
|
||||
void batch_generate_release(int request_id) {
|
||||
gpttype_batch_generate_release(request_id);
|
||||
}
|
||||
bool has_audio_support()
|
||||
{
|
||||
return audio_multimodal_supported;
|
||||
|
|
|
|||
11
expose.h
11
expose.h
|
|
@ -84,6 +84,7 @@ struct load_model_inputs
|
|||
const char * devices_override = nullptr;
|
||||
const bool quiet = false;
|
||||
const int debugmode = 0;
|
||||
const int continuous_batching_slots = 0;
|
||||
};
|
||||
struct generation_inputs
|
||||
{
|
||||
|
|
@ -385,3 +386,13 @@ extern int total_transcribe_gens;
|
|||
extern int last_draft_success;
|
||||
extern int last_draft_failed;
|
||||
extern stop_reason last_stop_reason;
|
||||
|
||||
bool gpttype_batch_generate_enabled();
|
||||
int gpttype_batch_generate_submit(const generation_inputs inputs);
|
||||
bool gpttype_batch_generate_has_finished(int request_id);
|
||||
int gpttype_batch_generate_stream_count(int request_id);
|
||||
const char * gpttype_batch_generate_new_token(int request_id, int idx);
|
||||
const char * gpttype_batch_generate_pending_output(int request_id);
|
||||
generation_outputs gpttype_batch_generate_result(int request_id);
|
||||
bool gpttype_batch_generate_abort(int request_id);
|
||||
void gpttype_batch_generate_release(int request_id);
|
||||
|
|
|
|||
|
|
@ -22,6 +22,11 @@
|
|||
#include <cctype>
|
||||
#include <locale>
|
||||
#include <chrono>
|
||||
#include <algorithm>
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
#include "utils.h"
|
||||
#include "llmutils.h"
|
||||
|
|
@ -76,6 +81,7 @@ int last_draft_success = 0;
|
|||
int last_draft_failed = 0;
|
||||
stop_reason last_stop_reason = stop_reason::INVALID;
|
||||
std::vector<std::string> generated_tokens;
|
||||
static int continuous_batching_slots = 0;
|
||||
|
||||
llama_grammar * grammar = nullptr; //currently used grammar
|
||||
llama_grammar_parser parsed_grammar;
|
||||
|
|
@ -2197,6 +2203,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
kcpp_pipeline_parallelism = inputs.pipelineparallel;
|
||||
kcpp_data->n_batch = GetBatchSize(inputs.batchsize, in_file_format);
|
||||
kcpp_data->n_ubatch = kcpp_data->n_batch;
|
||||
continuous_batching_slots = (isGguf && inputs.continuous_batching_slots > 1) ? inputs.continuous_batching_slots : 0;
|
||||
if(continuous_batching_slots > 0)
|
||||
{
|
||||
printf("Continuous batching: prepared %d GGUF sequence slots.\n", continuous_batching_slots);
|
||||
}
|
||||
kcpp_data->vision_min_tokens = inputs.visionmintokens;
|
||||
kcpp_data->vision_max_tokens = inputs.visionmaxtokens;
|
||||
if(isGguf && kcpp_pipeline_parallelism)
|
||||
|
|
@ -2461,6 +2472,10 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
|
||||
llama_ctx_params.n_batch = kcpp_data->n_batch;
|
||||
llama_ctx_params.n_ubatch = kcpp_data->n_ubatch;
|
||||
if(continuous_batching_slots > 0)
|
||||
{
|
||||
llama_ctx_params.n_seq_max = continuous_batching_slots + 1;
|
||||
}
|
||||
llama_ctx_params.n_threads = kcpp_data->n_threads;
|
||||
llama_ctx_params.n_threads_batch = kcpp_data->n_blasthreads;
|
||||
|
||||
|
|
@ -3269,6 +3284,745 @@ bool gpttype_generate_abort()
|
|||
return true;
|
||||
}
|
||||
|
||||
enum class BatchState
|
||||
{
|
||||
WAITING,
|
||||
PREFILL,
|
||||
GENERATING,
|
||||
FINISHED,
|
||||
FAILED,
|
||||
ABORTED,
|
||||
};
|
||||
|
||||
struct BatchGenerateRequest
|
||||
{
|
||||
int id = 0;
|
||||
int slot = -1;
|
||||
BatchState state = BatchState::WAITING;
|
||||
std::string prompt;
|
||||
std::vector<std::string> stop_sequences;
|
||||
int max_context_length = 0;
|
||||
int max_length = 0;
|
||||
int seed = 0;
|
||||
float temperature = 0.0f;
|
||||
int top_k = 0;
|
||||
float top_p = 1.0f;
|
||||
float min_p = 0.0f;
|
||||
float typical_p = 1.0f;
|
||||
float rep_pen = 1.0f;
|
||||
float rep_pen_slope = 1.0f;
|
||||
int rep_pen_range = 0;
|
||||
float presence_penalty = 0.0f;
|
||||
bool allow_eos_token = true;
|
||||
bool bypass_eos_token = false;
|
||||
bool render_special = false;
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
int prompt_pos = 0;
|
||||
int n_past = 0;
|
||||
bool has_pending = false;
|
||||
llama_token pending_token = 0;
|
||||
int i_batch = -1;
|
||||
llama_sampler * sampler = nullptr;
|
||||
std::vector<std::string> generated_pieces;
|
||||
std::string output;
|
||||
int prompt_token_count = 0;
|
||||
int completion_token_count = 0;
|
||||
std::chrono::steady_clock::time_point start_time;
|
||||
stop_reason finish_reason = stop_reason::INVALID;
|
||||
bool abort_requested = false;
|
||||
generation_outputs result;
|
||||
|
||||
~BatchGenerateRequest()
|
||||
{
|
||||
if(sampler)
|
||||
{
|
||||
llama_sampler_free(sampler);
|
||||
sampler = nullptr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static std::mutex batch_mutex;
|
||||
static std::condition_variable batch_cv;
|
||||
static std::deque<int> batch_waiting;
|
||||
static std::vector<std::unique_ptr<BatchGenerateRequest>> batch_requests;
|
||||
static std::thread batch_worker_thread;
|
||||
static bool batch_worker_stop = false;
|
||||
static bool batch_worker_started = false;
|
||||
static bool batch_legacy_active = false;
|
||||
static bool batch_touched_since_legacy = false;
|
||||
static int batch_legacy_waiting = 0;
|
||||
static int batch_next_request_id = 1;
|
||||
static std::string batch_empty_string = "";
|
||||
|
||||
static BatchGenerateRequest * batch_find_request_locked(int request_id)
|
||||
{
|
||||
for(auto & req : batch_requests)
|
||||
{
|
||||
if(req && req->id == request_id)
|
||||
{
|
||||
return req.get();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool batch_is_live_state(BatchState state)
|
||||
{
|
||||
return state == BatchState::WAITING || state == BatchState::PREFILL || state == BatchState::GENERATING;
|
||||
}
|
||||
|
||||
static bool batch_has_live_locked()
|
||||
{
|
||||
for(const auto & req : batch_requests)
|
||||
{
|
||||
if(req && batch_is_live_state(req->state))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void batch_invalidate_legacy_context_locked()
|
||||
{
|
||||
if(!batch_touched_since_legacy)
|
||||
{
|
||||
return;
|
||||
}
|
||||
batch_touched_since_legacy = false;
|
||||
n_past = 0;
|
||||
current_context_tokens.clear();
|
||||
last_n_tokens.clear();
|
||||
smartcontext.clear();
|
||||
loaded_latest_logits.clear();
|
||||
if(llama_ctx_v4)
|
||||
{
|
||||
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), 0, -1, -1);
|
||||
}
|
||||
if(draft_ctx)
|
||||
{
|
||||
llama_memory_seq_rm(llama_get_memory(draft_ctx), 0, -1, -1);
|
||||
}
|
||||
if(debugmode==1 && !is_quiet)
|
||||
{
|
||||
printf("\n[Continuous batching touched shared context; forcing next legacy generation to reprocess prompt]\n");
|
||||
}
|
||||
}
|
||||
|
||||
class BatchLegacyGuard
|
||||
{
|
||||
public:
|
||||
BatchLegacyGuard()
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(batch_mutex);
|
||||
batch_legacy_waiting++;
|
||||
batch_cv.notify_all();
|
||||
batch_cv.wait(lock, [](){ return !batch_has_live_locked(); });
|
||||
batch_legacy_waiting--;
|
||||
batch_invalidate_legacy_context_locked();
|
||||
batch_legacy_active = true;
|
||||
}
|
||||
|
||||
~BatchLegacyGuard()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
batch_legacy_active = false;
|
||||
batch_cv.notify_all();
|
||||
}
|
||||
};
|
||||
|
||||
static bool batch_inputs_eligible(const generation_inputs & inputs)
|
||||
{
|
||||
if(continuous_batching_slots <= 1 || file_format != FileFormat::GGUF_GENERIC || !llama_ctx_v4 || !kcpp_data)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(draft_ctx || guidance_ctx || clp_ctx_v || clp_ctx_a)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(kcpp_data->use_smartcontext || kcpp_data->use_contextshift || kcpp_data->smartcache)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.memory && std::string(inputs.memory).size() > 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.negative_prompt && std::string(inputs.negative_prompt).size() > 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f || !inputs.allow_eos_token)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.grammar && std::string(inputs.grammar).size() > 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.logit_biases_len > 0 || inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.mirostat != 0 || inputs.xtc_probability > 0.0f || inputs.nsigma > 0.0f || inputs.smoothing_factor > 0.0f || inputs.adaptive_target > 0.0f)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(inputs.top_a > 0.0f || inputs.tfs != 1.0f || inputs.dynatemp_range > 0.0f)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
static const int default_sampler_order[] = {6, 0, 1, 3, 4, 2, 5};
|
||||
if(inputs.sampler_len > 0)
|
||||
{
|
||||
if(inputs.sampler_len != 7)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(int i = 0; i < 7; ++i)
|
||||
{
|
||||
if((int) inputs.sampler_order[i] != default_sampler_order[i])
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(inputs.reasoning_budget >= 0 || inputs.tool_call_fix)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
struct BatchRepPenSampler
|
||||
{
|
||||
int32_t penalty_last_n = 0;
|
||||
float penalty_repeat = 1.0f;
|
||||
float penalty_slope = 1.0f;
|
||||
float penalty_present = 0.0f;
|
||||
std::vector<llama_token> prev;
|
||||
};
|
||||
|
||||
static const char * batch_rep_pen_name(const llama_sampler * /*smpl*/)
|
||||
{
|
||||
return "kcpp-batch-rep-pen";
|
||||
}
|
||||
|
||||
static void batch_rep_pen_accept(llama_sampler * smpl, llama_token token)
|
||||
{
|
||||
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
|
||||
if(ctx->penalty_last_n <= 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if(ctx->prev.size() >= (size_t) ctx->penalty_last_n)
|
||||
{
|
||||
ctx->prev.erase(ctx->prev.begin());
|
||||
}
|
||||
ctx->prev.push_back(token);
|
||||
}
|
||||
|
||||
static void batch_rep_pen_apply(llama_sampler * smpl, llama_token_data_array * cur_p)
|
||||
{
|
||||
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
|
||||
int last_n_repeat = std::min((int) ctx->prev.size(), ctx->penalty_last_n);
|
||||
if(last_n_repeat <= 0 || (ctx->penalty_repeat == 1.0f && ctx->penalty_present == 0.0f))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_token * last_tokens = ctx->prev.data() + ctx->prev.size() - last_n_repeat;
|
||||
std::unordered_set<llama_token> tokens_near(last_tokens + last_n_repeat / 2, last_tokens + last_n_repeat);
|
||||
std::unordered_set<llama_token> tokens_far(last_tokens, last_tokens + last_n_repeat / 2);
|
||||
|
||||
float penalty_reduced = ctx->penalty_repeat;
|
||||
if(penalty_reduced > 1.0f)
|
||||
{
|
||||
penalty_reduced = 1.0f + ((ctx->penalty_repeat - 1.0f) * ctx->penalty_slope);
|
||||
}
|
||||
|
||||
for(size_t i = 0; i < cur_p->size; ++i)
|
||||
{
|
||||
const bool token_in_near = tokens_near.find(cur_p->data[i].id) != tokens_near.end();
|
||||
const bool token_in_far = tokens_far.find(cur_p->data[i].id) != tokens_far.end();
|
||||
if(!token_in_near && !token_in_far)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
float penalty = token_in_near ? ctx->penalty_repeat : penalty_reduced;
|
||||
if(cur_p->data[i].logit <= 0)
|
||||
{
|
||||
cur_p->data[i].logit *= penalty;
|
||||
}
|
||||
else
|
||||
{
|
||||
cur_p->data[i].logit /= penalty;
|
||||
}
|
||||
cur_p->data[i].logit -= ctx->penalty_present;
|
||||
}
|
||||
|
||||
cur_p->sorted = false;
|
||||
}
|
||||
|
||||
static void batch_rep_pen_reset(llama_sampler * smpl)
|
||||
{
|
||||
auto * ctx = (BatchRepPenSampler *) smpl->ctx;
|
||||
ctx->prev.clear();
|
||||
}
|
||||
|
||||
static llama_sampler * batch_rep_pen_clone(const llama_sampler * smpl)
|
||||
{
|
||||
const auto * ctx = (const BatchRepPenSampler *) smpl->ctx;
|
||||
auto * result = llama_sampler_init(smpl->iface, new BatchRepPenSampler {
|
||||
ctx->penalty_last_n,
|
||||
ctx->penalty_repeat,
|
||||
ctx->penalty_slope,
|
||||
ctx->penalty_present,
|
||||
ctx->prev,
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
static void batch_rep_pen_free(llama_sampler * smpl)
|
||||
{
|
||||
delete (BatchRepPenSampler *) smpl->ctx;
|
||||
}
|
||||
|
||||
static llama_sampler_i batch_rep_pen_i = {
|
||||
/* .name = */ batch_rep_pen_name,
|
||||
/* .accept = */ batch_rep_pen_accept,
|
||||
/* .apply = */ batch_rep_pen_apply,
|
||||
/* .reset = */ batch_rep_pen_reset,
|
||||
/* .clone = */ batch_rep_pen_clone,
|
||||
/* .free = */ batch_rep_pen_free,
|
||||
/* .backend_init = */ nullptr,
|
||||
/* .backend_accept = */ nullptr,
|
||||
/* .backend_apply = */ nullptr,
|
||||
/* .backend_set_input = */ nullptr,
|
||||
};
|
||||
|
||||
static llama_sampler * batch_rep_pen_init(int32_t penalty_last_n, float penalty_repeat, float penalty_slope, float penalty_present)
|
||||
{
|
||||
penalty_last_n = std::max(penalty_last_n, 0);
|
||||
if(penalty_slope <= 0.0f || penalty_slope > 1.0f)
|
||||
{
|
||||
penalty_slope = 1.0f;
|
||||
}
|
||||
return llama_sampler_init(&batch_rep_pen_i, new BatchRepPenSampler {
|
||||
penalty_last_n,
|
||||
penalty_repeat <= 0.0f ? 1.0f : penalty_repeat,
|
||||
penalty_slope,
|
||||
penalty_present,
|
||||
{},
|
||||
});
|
||||
}
|
||||
|
||||
static llama_sampler * batch_build_sampler(const BatchGenerateRequest & req)
|
||||
{
|
||||
llama_sampler_chain_params params = llama_sampler_chain_default_params();
|
||||
llama_sampler * chain = llama_sampler_chain_init(params);
|
||||
llama_sampler_chain_add(chain, batch_rep_pen_init(
|
||||
req.rep_pen_range,
|
||||
req.rep_pen,
|
||||
req.rep_pen_slope,
|
||||
req.presence_penalty));
|
||||
if(req.top_k > 0)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_k(req.top_k));
|
||||
}
|
||||
if(req.top_p > 0.0f && req.top_p < 1.0f)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_p(req.top_p, 1));
|
||||
}
|
||||
if(req.min_p > 0.0f)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_min_p(req.min_p, 1));
|
||||
}
|
||||
if(req.typical_p > 0.0f && req.typical_p < 1.0f)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_typical(req.typical_p, 1));
|
||||
}
|
||||
if(req.temperature > 0.0f)
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_temp(req.temperature));
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_dist(req.seed < 0 ? LLAMA_DEFAULT_SEED : (uint32_t) req.seed));
|
||||
}
|
||||
else
|
||||
{
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_greedy());
|
||||
}
|
||||
return chain;
|
||||
}
|
||||
|
||||
static void batch_finish_request_locked(BatchGenerateRequest & req, stop_reason reason)
|
||||
{
|
||||
auto finish_time = std::chrono::steady_clock::now();
|
||||
float total_time = req.start_time.time_since_epoch().count() == 0 ? 0.0f : std::chrono::duration<float>(finish_time - req.start_time).count();
|
||||
float generated_tps = total_time > 0.0f ? (float) req.completion_token_count / total_time : 0.0f;
|
||||
req.finish_reason = reason;
|
||||
req.result.status = (reason == stop_reason::ERROR_ENCOUNTERED) ? 0 : 1;
|
||||
req.result.stopreason = reason;
|
||||
req.result.prompt_tokens = req.prompt_token_count;
|
||||
req.result.completion_tokens = req.completion_token_count;
|
||||
req.result.text = req.output.c_str();
|
||||
req.state = reason == stop_reason::ERROR_ENCOUNTERED ? BatchState::FAILED : (reason == stop_reason::INVALID ? BatchState::ABORTED : BatchState::FINISHED);
|
||||
if(req.slot >= 0 && llama_ctx_v4)
|
||||
{
|
||||
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), req.slot, -1, -1);
|
||||
}
|
||||
req.slot = -1;
|
||||
printf("\n[%s] BatchRequest:%d, Prompt:%d, Generated:%d/%d in %.2fs (%.2fT/s), Stop:%d",
|
||||
get_timestamp_str().c_str(), req.id, req.prompt_token_count, req.completion_token_count, req.max_length, total_time, generated_tps, (int) reason);
|
||||
fflush(stdout);
|
||||
batch_cv.notify_all();
|
||||
}
|
||||
|
||||
static bool batch_output_hit_stop(const BatchGenerateRequest & req)
|
||||
{
|
||||
for(const auto & stopper : req.stop_sequences)
|
||||
{
|
||||
if(!stopper.empty() && req.output.find(stopper) != std::string::npos)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool batch_claim_waiting_locked()
|
||||
{
|
||||
bool claimed = false;
|
||||
for(int slot = 1; slot <= continuous_batching_slots && !batch_waiting.empty(); ++slot)
|
||||
{
|
||||
bool occupied = false;
|
||||
for(const auto & req : batch_requests)
|
||||
{
|
||||
if(req && req->slot == slot && batch_is_live_state(req->state))
|
||||
{
|
||||
occupied = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(occupied)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
int request_id = batch_waiting.front();
|
||||
batch_waiting.pop_front();
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req || req->state != BatchState::WAITING)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
req->slot = slot;
|
||||
req->state = BatchState::PREFILL;
|
||||
batch_touched_since_legacy = true;
|
||||
TokenizeString(req->prompt, req->prompt_tokens, file_format, add_bos_token);
|
||||
if(req->prompt_tokens.empty())
|
||||
{
|
||||
TokenizeString("", req->prompt_tokens, file_format, add_bos_token);
|
||||
}
|
||||
int n_ctx = req->max_context_length > 0 ? std::min(req->max_context_length, kcpp_data->n_ctx) : kcpp_data->n_ctx;
|
||||
if(req->max_length > 0 && (int) req->prompt_tokens.size() + req->max_length > n_ctx)
|
||||
{
|
||||
int keep = std::max(1, n_ctx - req->max_length);
|
||||
if((int) req->prompt_tokens.size() > keep)
|
||||
{
|
||||
req->prompt_tokens.erase(req->prompt_tokens.begin(), req->prompt_tokens.end() - keep);
|
||||
}
|
||||
}
|
||||
req->prompt_token_count = req->prompt_tokens.size();
|
||||
req->sampler = batch_build_sampler(*req);
|
||||
for(llama_token token : req->prompt_tokens)
|
||||
{
|
||||
llama_sampler_accept(req->sampler, token);
|
||||
}
|
||||
req->prompt_pos = 0;
|
||||
req->n_past = 0;
|
||||
req->has_pending = false;
|
||||
req->i_batch = -1;
|
||||
req->start_time = std::chrono::steady_clock::now();
|
||||
llama_memory_seq_rm(llama_get_memory(llama_ctx_v4), slot, -1, -1);
|
||||
claimed = true;
|
||||
}
|
||||
return claimed;
|
||||
}
|
||||
|
||||
static void batch_worker_loop()
|
||||
{
|
||||
const int batch_cap = std::max(1, kcpp_data ? kcpp_data->n_batch : 512);
|
||||
llama_batch batch = llama_batch_init(batch_cap, 0, 1);
|
||||
while(true)
|
||||
{
|
||||
std::vector<int> decode_ids;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(batch_mutex);
|
||||
batch_cv.wait_for(lock, std::chrono::milliseconds(5), [](){
|
||||
return batch_worker_stop || (!batch_legacy_active && batch_has_live_locked());
|
||||
});
|
||||
if(batch_worker_stop)
|
||||
{
|
||||
break;
|
||||
}
|
||||
if(batch_legacy_active)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
batch_claim_waiting_locked();
|
||||
common_batch_clear(batch);
|
||||
for(auto & req_ptr : batch_requests)
|
||||
{
|
||||
if(!req_ptr || !batch_is_live_state(req_ptr->state) || req_ptr->slot < 0 || batch.n_tokens >= batch_cap)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
BatchGenerateRequest & req = *req_ptr;
|
||||
req.i_batch = -1;
|
||||
if(req.abort_requested)
|
||||
{
|
||||
batch_finish_request_locked(req, stop_reason::INVALID);
|
||||
continue;
|
||||
}
|
||||
if(req.state == BatchState::PREFILL)
|
||||
{
|
||||
while(req.prompt_pos < (int) req.prompt_tokens.size() && batch.n_tokens < batch_cap)
|
||||
{
|
||||
bool is_last = req.prompt_pos == (int) req.prompt_tokens.size() - 1;
|
||||
if(is_last)
|
||||
{
|
||||
req.i_batch = batch.n_tokens;
|
||||
}
|
||||
common_batch_add(batch, req.prompt_tokens[req.prompt_pos], req.n_past, { req.slot }, is_last);
|
||||
req.prompt_pos++;
|
||||
req.n_past++;
|
||||
}
|
||||
if(req.prompt_pos == (int) req.prompt_tokens.size())
|
||||
{
|
||||
req.state = BatchState::GENERATING;
|
||||
}
|
||||
}
|
||||
else if(req.state == BatchState::GENERATING && req.has_pending)
|
||||
{
|
||||
req.i_batch = batch.n_tokens;
|
||||
common_batch_add(batch, req.pending_token, req.n_past, { req.slot }, true);
|
||||
req.n_past++;
|
||||
req.has_pending = false;
|
||||
}
|
||||
}
|
||||
if(batch.n_tokens == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(auto & req_ptr : batch_requests)
|
||||
{
|
||||
if(req_ptr && req_ptr->i_batch >= 0)
|
||||
{
|
||||
decode_ids.push_back(req_ptr->id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int decode_status = llama_decode(llama_ctx_v4, batch);
|
||||
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
if(decode_status != 0)
|
||||
{
|
||||
for(int request_id : decode_ids)
|
||||
{
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(req && batch_is_live_state(req->state))
|
||||
{
|
||||
batch_finish_request_locked(*req, stop_reason::ERROR_ENCOUNTERED);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
|
||||
llama_token eos = llama_vocab_eos(vocab);
|
||||
for(int request_id : decode_ids)
|
||||
{
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req || req->state != BatchState::GENERATING || req->i_batch < 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
llama_token sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
|
||||
if(!req->allow_eos_token && !req->bypass_eos_token && sampled == eos)
|
||||
{
|
||||
sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
|
||||
}
|
||||
req->completion_token_count++;
|
||||
if(sampled == eos && !req->bypass_eos_token)
|
||||
{
|
||||
batch_finish_request_locked(*req, stop_reason::EOS_TOKEN_HIT);
|
||||
continue;
|
||||
}
|
||||
std::string piece = FileFormatTokenizeID(sampled, file_format, req->render_special);
|
||||
req->generated_pieces.push_back(piece);
|
||||
req->output += piece;
|
||||
if(batch_output_hit_stop(*req))
|
||||
{
|
||||
batch_finish_request_locked(*req, stop_reason::CUSTOM_STOPPER);
|
||||
continue;
|
||||
}
|
||||
if(req->max_length > 0 && req->completion_token_count >= req->max_length)
|
||||
{
|
||||
batch_finish_request_locked(*req, stop_reason::OUT_OF_TOKENS);
|
||||
continue;
|
||||
}
|
||||
req->pending_token = sampled;
|
||||
req->has_pending = true;
|
||||
req->i_batch = -1;
|
||||
}
|
||||
}
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
static void batch_start_worker_locked()
|
||||
{
|
||||
if(batch_worker_started)
|
||||
{
|
||||
return;
|
||||
}
|
||||
batch_worker_stop = false;
|
||||
batch_worker_thread = std::thread(batch_worker_loop);
|
||||
batch_worker_thread.detach();
|
||||
batch_worker_started = true;
|
||||
}
|
||||
|
||||
bool gpttype_batch_generate_enabled()
|
||||
{
|
||||
return continuous_batching_slots > 1 && file_format == FileFormat::GGUF_GENERIC && llama_ctx_v4 && kcpp_data;
|
||||
}
|
||||
|
||||
int gpttype_batch_generate_submit(const generation_inputs inputs)
|
||||
{
|
||||
if(!batch_inputs_eligible(inputs))
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
if(batch_legacy_active || batch_legacy_waiting > 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
auto req = std::make_unique<BatchGenerateRequest>();
|
||||
req->id = batch_next_request_id++;
|
||||
req->prompt = inputs.prompt ? inputs.prompt : "";
|
||||
req->max_context_length = inputs.max_context_length;
|
||||
req->max_length = inputs.max_length;
|
||||
req->seed = inputs.seed;
|
||||
req->temperature = inputs.temperature;
|
||||
req->top_k = inputs.top_k;
|
||||
req->top_p = inputs.top_p;
|
||||
req->min_p = inputs.min_p;
|
||||
req->typical_p = inputs.typical_p;
|
||||
req->rep_pen = inputs.rep_pen;
|
||||
req->rep_pen_slope = inputs.rep_pen_slope;
|
||||
req->rep_pen_range = inputs.rep_pen_range;
|
||||
req->presence_penalty = inputs.presence_penalty;
|
||||
req->allow_eos_token = inputs.allow_eos_token;
|
||||
req->bypass_eos_token = inputs.bypass_eos_token;
|
||||
req->render_special = inputs.render_special;
|
||||
for(int i = 0; i < inputs.stop_sequence_len; ++i)
|
||||
{
|
||||
if(inputs.stop_sequence[i])
|
||||
{
|
||||
req->stop_sequences.emplace_back(inputs.stop_sequence[i]);
|
||||
}
|
||||
}
|
||||
int request_id = req->id;
|
||||
batch_requests.emplace_back(std::move(req));
|
||||
batch_waiting.push_back(request_id);
|
||||
batch_start_worker_locked();
|
||||
batch_cv.notify_all();
|
||||
return request_id;
|
||||
}
|
||||
|
||||
bool gpttype_batch_generate_has_finished(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
return !req || !batch_is_live_state(req->state);
|
||||
}
|
||||
|
||||
int gpttype_batch_generate_stream_count(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
return req ? req->generated_pieces.size() : 0;
|
||||
}
|
||||
|
||||
const char * gpttype_batch_generate_new_token(int request_id, int idx)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req || idx < 0 || idx >= (int) req->generated_pieces.size())
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
return req->generated_pieces[idx].c_str();
|
||||
}
|
||||
|
||||
const char * gpttype_batch_generate_pending_output(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req)
|
||||
{
|
||||
return batch_empty_string.c_str();
|
||||
}
|
||||
return req->output.c_str();
|
||||
}
|
||||
|
||||
generation_outputs gpttype_batch_generate_result(int request_id)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(batch_mutex);
|
||||
batch_cv.wait(lock, [request_id](){
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
return !req || !batch_is_live_state(req->state);
|
||||
});
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req)
|
||||
{
|
||||
generation_outputs output;
|
||||
output.status = 0;
|
||||
output.stopreason = stop_reason::ERROR_ENCOUNTERED;
|
||||
output.prompt_tokens = 0;
|
||||
output.completion_tokens = 0;
|
||||
output.text = batch_empty_string.c_str();
|
||||
return output;
|
||||
}
|
||||
req->result.text = req->output.c_str();
|
||||
return req->result;
|
||||
}
|
||||
|
||||
bool gpttype_batch_generate_abort(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
BatchGenerateRequest * req = batch_find_request_locked(request_id);
|
||||
if(!req)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
req->abort_requested = true;
|
||||
batch_cv.notify_all();
|
||||
return true;
|
||||
}
|
||||
|
||||
void gpttype_batch_generate_release(int request_id)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(batch_mutex);
|
||||
batch_requests.erase(std::remove_if(batch_requests.begin(), batch_requests.end(), [request_id](const std::unique_ptr<BatchGenerateRequest> & req){
|
||||
return req && req->id == request_id && !batch_is_live_state(req->state);
|
||||
}), batch_requests.end());
|
||||
batch_cv.notify_all();
|
||||
}
|
||||
|
||||
std::string gpttype_get_chat_template()
|
||||
{
|
||||
if(kcpp_data==nullptr)
|
||||
|
|
@ -3547,6 +4301,7 @@ int smartcache_quick_snapshot(int specific_slot = -1)
|
|||
|
||||
generation_outputs gpttype_generate(const generation_inputs inputs)
|
||||
{
|
||||
BatchLegacyGuard batch_legacy_guard;
|
||||
generation_outputs output;
|
||||
|
||||
if(kcpp_data==nullptr)
|
||||
|
|
|
|||
138
koboldcpp.py
138
koboldcpp.py
|
|
@ -110,6 +110,9 @@ maxctx = 8192
|
|||
maxhordectx = 0 #set to whatever maxctx is if 0
|
||||
maxhordelen = 1024
|
||||
modelbusy = threading.Lock()
|
||||
batched_lock = threading.Lock()
|
||||
batched_cond = threading.Condition(batched_lock)
|
||||
batched_request_runner_count = 0 #incremented when a batched request is running, prevents all non-batched requests
|
||||
requestsinqueue = 0
|
||||
ratelimitlookup = {}
|
||||
defaultport = 5001
|
||||
|
|
@ -284,7 +287,8 @@ class load_model_inputs(ctypes.Structure):
|
|||
("lora_multiplier", ctypes.c_float),
|
||||
("devices_override", ctypes.c_char_p),
|
||||
("quiet", ctypes.c_bool),
|
||||
("debugmode", ctypes.c_int)]
|
||||
("debugmode", ctypes.c_int),
|
||||
("continuous_batching_slots", ctypes.c_int)]
|
||||
|
||||
class generation_inputs(ctypes.Structure):
|
||||
_fields_ = [("seed", ctypes.c_int),
|
||||
|
|
@ -893,6 +897,23 @@ def init_library():
|
|||
handle.new_token.argtypes = [ctypes.c_int]
|
||||
handle.get_stream_count.restype = ctypes.c_int
|
||||
handle.has_finished.restype = ctypes.c_bool
|
||||
handle.batch_generate_enabled.restype = ctypes.c_bool
|
||||
handle.batch_generate_submit.argtypes = [generation_inputs]
|
||||
handle.batch_generate_submit.restype = ctypes.c_int
|
||||
handle.batch_generate_has_finished.argtypes = [ctypes.c_int]
|
||||
handle.batch_generate_has_finished.restype = ctypes.c_bool
|
||||
handle.batch_generate_stream_count.argtypes = [ctypes.c_int]
|
||||
handle.batch_generate_stream_count.restype = ctypes.c_int
|
||||
handle.batch_generate_new_token.argtypes = [ctypes.c_int, ctypes.c_int]
|
||||
handle.batch_generate_new_token.restype = ctypes.c_char_p
|
||||
handle.batch_generate_pending_output.argtypes = [ctypes.c_int]
|
||||
handle.batch_generate_pending_output.restype = ctypes.c_char_p
|
||||
handle.batch_generate_result.argtypes = [ctypes.c_int]
|
||||
handle.batch_generate_result.restype = generation_outputs
|
||||
handle.batch_generate_abort.argtypes = [ctypes.c_int]
|
||||
handle.batch_generate_abort.restype = ctypes.c_bool
|
||||
handle.batch_generate_release.argtypes = [ctypes.c_int]
|
||||
handle.batch_generate_release.restype = None
|
||||
handle.has_audio_support.restype = ctypes.c_bool
|
||||
handle.has_vision_support.restype = ctypes.c_bool
|
||||
handle.get_last_eval_time.restype = ctypes.c_float
|
||||
|
|
@ -1912,6 +1933,9 @@ def load_model(model_filename):
|
|||
inputs.visionmintokens = vmintk
|
||||
inputs.visionmaxtokens = vmaxtk
|
||||
inputs.use_smartcontext = args.smartcontext
|
||||
if getattr(args, "continuous_batching", 0) > 1 and not args.noshift:
|
||||
print("\nWarning: Continuous batching is enabled, so context shifting has been disabled automatically.\n")
|
||||
args.noshift = True
|
||||
inputs.use_contextshift = (0 if args.noshift else 1)
|
||||
inputs.use_fastforward = (0 if args.nofastforward else 1)
|
||||
inputs.flash_attention = (False if args.noflashattention else True)
|
||||
|
|
@ -1984,6 +2008,7 @@ def load_model(model_filename):
|
|||
savestate_limit = sclimit
|
||||
inputs.smartcacheslots = sclimit
|
||||
inputs.pipelineparallel = (not args.nopipelineparallel)
|
||||
inputs.continuous_batching_slots = int(args.continuous_batching) if hasattr(args, "continuous_batching") else 0
|
||||
inputs = set_backend_props(inputs)
|
||||
ret = handle.load_model(inputs)
|
||||
return ret
|
||||
|
|
@ -2230,10 +2255,26 @@ def generate(genparams, stream_flag=False):
|
|||
pendingabortkey = ""
|
||||
return {"text":"","status":-1,"stopreason":-1, "prompt_tokens":0, "completion_tokens": 0, "total_tokens": 0}
|
||||
else:
|
||||
ret = handle.generate(inputs)
|
||||
batch_request_id = -1
|
||||
if getattr(args, "continuous_batching", 0) > 1:
|
||||
try:
|
||||
batch_request_id = handle.batch_generate_submit(inputs)
|
||||
except Exception:
|
||||
batch_request_id = -1
|
||||
if batch_request_id >= 0:
|
||||
genparams['_batch_request_id'] = batch_request_id
|
||||
ret = handle.batch_generate_result(batch_request_id)
|
||||
else:
|
||||
genparams['_batch_fallback'] = True
|
||||
ret = handle.generate(inputs)
|
||||
outstr = ""
|
||||
if ret.status==1:
|
||||
outstr = ret.text.decode("UTF-8","ignore")
|
||||
if batch_request_id >= 0 and not stream_flag:
|
||||
handle.batch_generate_release(batch_request_id)
|
||||
genparams.pop('_batch_request_id', None)
|
||||
genparams.pop('_batch_expected', None)
|
||||
genparams.pop('_batch_fallback', None)
|
||||
if trimstop:
|
||||
for trim_str in stop_sequence:
|
||||
sindex = outstr.find(trim_str)
|
||||
|
|
@ -2241,6 +2282,32 @@ def generate(genparams, stream_flag=False):
|
|||
outstr = outstr[:sindex]
|
||||
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
|
||||
|
||||
def continuous_batching_python_eligible(genparams, api_format):
|
||||
if getattr(args, "continuous_batching", 0) <= 1 or api_format <= 0:
|
||||
return False
|
||||
model_path = str(getattr(args, "model_param", "") or "").lower()
|
||||
if model_path and not model_path.endswith(".gguf"):
|
||||
return False
|
||||
if not getattr(args, "noshift", False) or getattr(args, "smartcontext", False) or getattr(args, "draftmodel", "") or getattr(args, "mmproj", "") or getattr(args, "enableguidance", False):
|
||||
return False
|
||||
if genparams.get("memory") or genparams.get("negative_prompt") or genparams.get("images") or genparams.get("audio"):
|
||||
return False
|
||||
if genparams.get("ban_eos_token", False):
|
||||
return False
|
||||
if genparams.get("grammar") or genparams.get("grammar_retain_state") or genparams.get("logit_bias") or genparams.get("banned_tokens") or genparams.get("banned_strings"):
|
||||
return False
|
||||
if tryparsefloat(genparams.get("dry_multiplier", 0), 0) or tryparseint(genparams.get("mirostat", 0), 0) or tryparsefloat(genparams.get("xtc_probability", 0), 0) or tryparsefloat(genparams.get("nsigma", 0), 0):
|
||||
return False
|
||||
if tryparsefloat(genparams.get("smoothing_factor", 0), 0) or tryparsefloat(genparams.get("adaptive_target", -1), -1) > 0 or genparams.get("using_openai_tools", False):
|
||||
return False
|
||||
if tryparsefloat(genparams.get("top_a", 0), 0) or tryparsefloat(genparams.get("tfs", 1), 1) != 1 or tryparsefloat(genparams.get("dynatemp_range", 0), 0):
|
||||
return False
|
||||
if genparams.get("sampler_order") and genparams.get("sampler_order") != [6, 0, 1, 3, 4, 2, 5]:
|
||||
return False
|
||||
if genparams.get("reasoning_effort"):
|
||||
return False
|
||||
return True
|
||||
|
||||
def sd_get_info():
|
||||
info = handle.sd_get_info()
|
||||
if info.status == 0:
|
||||
|
|
@ -4922,19 +4989,29 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
incomplete_token_buffer = bytearray()
|
||||
async_sleep_short = 0.02
|
||||
await asyncio.sleep(0.35) #anti race condition, prevent check from overtaking generate
|
||||
batch_request_id = genparams.get('_batch_request_id', -1)
|
||||
batch_final_result = None
|
||||
|
||||
try:
|
||||
tokenReserve = "" #keeps fully formed tokens that we cannot send out yet
|
||||
while True:
|
||||
streamDone = handle.has_finished() #exit next loop on done
|
||||
if batch_request_id < 0:
|
||||
batch_request_id = genparams.get('_batch_request_id', -1)
|
||||
if genparams.get('_batch_expected', False) and batch_request_id < 0 and not genparams.get('_batch_fallback', False):
|
||||
await asyncio.sleep(async_sleep_short)
|
||||
continue
|
||||
using_batch_stream = batch_request_id >= 0
|
||||
streamDone = handle.batch_generate_has_finished(batch_request_id) if using_batch_stream else handle.has_finished() #exit next loop on done
|
||||
if streamDone:
|
||||
sr = handle.get_last_stop_reason()
|
||||
if using_batch_stream and batch_final_result is None:
|
||||
batch_final_result = handle.batch_generate_result(batch_request_id)
|
||||
sr = batch_final_result.stopreason if using_batch_stream else handle.get_last_stop_reason()
|
||||
currfinishreason = "error" if sr==-2 else ("length" if (sr!=1) else "stop")
|
||||
prompttokens = handle.get_last_input_count()
|
||||
prompttokens = batch_final_result.prompt_tokens if using_batch_stream else handle.get_last_input_count()
|
||||
tokenStr = ""
|
||||
streamcount = handle.get_stream_count()
|
||||
streamcount = handle.batch_generate_stream_count(batch_request_id) if using_batch_stream else handle.get_stream_count()
|
||||
while current_token < streamcount:
|
||||
token = handle.new_token(current_token)
|
||||
token = handle.batch_generate_new_token(batch_request_id, current_token) if using_batch_stream else handle.new_token(current_token)
|
||||
|
||||
if token is None: # Token isnt ready yet, received nullpointer
|
||||
break
|
||||
|
|
@ -5147,7 +5224,8 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
if streamDone:
|
||||
# content_part.done, reply full text
|
||||
await asyncio.sleep(async_sleep_short)
|
||||
finaltxt = handle.get_pending_output().decode("UTF-8", "ignore")
|
||||
finalraw = handle.batch_generate_pending_output(batch_request_id) if using_batch_stream else handle.get_pending_output()
|
||||
finaltxt = finalraw.decode("UTF-8", "ignore")
|
||||
await asyncio.sleep(async_sleep_short)
|
||||
done_event = json.dumps({"type": "response.output_text.done", "item_id": item_id, "output_index": 0, "sequence_number":rseq_num, "content_index": 0, "text": finaltxt})
|
||||
rseq_num += 1
|
||||
|
|
@ -5156,7 +5234,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
item_done = json.dumps({"type": "response.output_item.done", "output_index": 0, "sequence_number":rseq_num, "item": { "type": "message", "id": item_id, "status": "completed", "role": "assistant", "content": [{"type": "output_text", "text": finaltxt, "annotations": [], "logprobs": []}]}})
|
||||
rseq_num += 1
|
||||
await self.send_oai_responses_sse_event("response.output_item.done",item_done)
|
||||
usage_pp = handle.get_last_input_count()
|
||||
usage_pp = batch_final_result.prompt_tokens if using_batch_stream else handle.get_last_input_count()
|
||||
usage_gen = current_token
|
||||
res = self.prepare_basic_responses_body(resp_id,genparams)
|
||||
res["completed_at"] = int(time.time())
|
||||
|
|
@ -5204,8 +5282,17 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
except Exception as ex:
|
||||
print("Token streaming was interrupted or aborted!")
|
||||
print(ex)
|
||||
handle.abort_generate()
|
||||
if batch_request_id >= 0:
|
||||
handle.batch_generate_abort(batch_request_id)
|
||||
else:
|
||||
handle.abort_generate()
|
||||
await asyncio.sleep(0.2) #short delay
|
||||
finally:
|
||||
if batch_request_id >= 0:
|
||||
handle.batch_generate_release(batch_request_id)
|
||||
genparams.pop('_batch_request_id', None)
|
||||
genparams.pop('_batch_expected', None)
|
||||
genparams.pop('_batch_fallback', None)
|
||||
|
||||
# flush buffers, sleep a bit to make sure all data sent, and then force close the connection
|
||||
self.wfile.flush()
|
||||
|
|
@ -5271,7 +5358,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
pass
|
||||
|
||||
def get_multiplayer_idle_state(self,userid):
|
||||
if modelbusy.locked():
|
||||
if modelbusy.locked() or batched_request_runner_count>0:
|
||||
return False
|
||||
for key, value in multiplayer_lastactive.items():
|
||||
if key!=userid and time.time()-value<6: #6s to idle
|
||||
|
|
@ -5308,7 +5395,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
return True
|
||||
|
||||
def noscript_webui(self):
|
||||
global modelbusy, sslvalid
|
||||
global modelbusy, sslvalid, batched_request_runner_count
|
||||
parsed_url = urllib.parse.urlparse(self.path)
|
||||
parsed_dict = urllib.parse.parse_qs(parsed_url.query)
|
||||
reply = ""
|
||||
|
|
@ -5348,7 +5435,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
else:
|
||||
gencommand = False
|
||||
|
||||
if modelbusy.locked():
|
||||
if modelbusy.locked() or batched_request_runner_count>0:
|
||||
status = "Model is currently busy, try again later."
|
||||
elif gencommand:
|
||||
if prompt=="" or max_length<=0:
|
||||
|
|
@ -5587,7 +5674,7 @@ Change Mode<br>
|
|||
"total_tts_gens": totalttsgens,
|
||||
"total_transcribe_gens": totaltranscribegens,
|
||||
"queue": requestsinqueue,
|
||||
"idle": (0 if modelbusy.locked() else 1),
|
||||
"idle": (0 if (modelbusy.locked() or batched_request_runner_count>0) else 1),
|
||||
"hordeexitcounter": exitcounter,
|
||||
"uptime": uptime,
|
||||
"idletime": idletime,
|
||||
|
|
@ -5850,7 +5937,7 @@ Change Mode<br>
|
|||
|
||||
def do_POST(self):
|
||||
global thinkformats
|
||||
global modelbusy, requestsinqueue, currentusergenkey, totalgens, pendingabortkey, lastuploadedcomfyimg, lastgeneratedcomfyimg, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, net_save_slots, has_vision_support, savestate_limit, mcp_lock
|
||||
global modelbusy, batched_request_runner_count, requestsinqueue, currentusergenkey, totalgens, pendingabortkey, lastuploadedcomfyimg, lastgeneratedcomfyimg, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, net_save_slots, has_vision_support, savestate_limit, mcp_lock
|
||||
global autoswapmode, textName, sttName, ttsName, embedName, musicName, imageName, mmprojName
|
||||
contlenstr = self.headers['content-length']
|
||||
content_length = 0
|
||||
|
|
@ -6323,6 +6410,7 @@ Change Mode<br>
|
|||
"type": "service_unavailable",
|
||||
}}).encode())
|
||||
return
|
||||
is_batchable_req = False
|
||||
if reqblocking:
|
||||
requestsinqueue = (requestsinqueue - 1) if requestsinqueue > 0 else 0
|
||||
|
||||
|
|
@ -6531,10 +6619,22 @@ Change Mode<br>
|
|||
if args.foreground:
|
||||
bring_terminal_to_foreground()
|
||||
|
||||
#if it's a non-batchable request and we already have batching ongoing, stall this request
|
||||
if batched_request_runner_count > 0 and not continuous_batching_python_eligible(genparams, api_format):
|
||||
with batched_cond:
|
||||
while batched_request_runner_count > 0:
|
||||
batched_cond.wait()
|
||||
|
||||
if api_format > 0: #text gen
|
||||
# Check if streaming chat completions, if so, set stream mode to true
|
||||
if (api_format == 4 or api_format == 3 or api_format == 8 or api_format == 9) and "stream" in genparams and genparams["stream"]:
|
||||
sse_stream_flag = True
|
||||
if continuous_batching_python_eligible(genparams, api_format):
|
||||
genparams['_batch_expected'] = True
|
||||
modelbusy.release()
|
||||
is_batchable_req = True
|
||||
with batched_cond:
|
||||
batched_request_runner_count += 1
|
||||
|
||||
gendat = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))
|
||||
|
||||
|
|
@ -6887,7 +6987,12 @@ Change Mode<br>
|
|||
|
||||
finally:
|
||||
time.sleep(0.05)
|
||||
modelbusy.release()
|
||||
if is_batchable_req:
|
||||
with batched_cond:
|
||||
batched_request_runner_count -= 1
|
||||
batched_cond.notify_all()
|
||||
else:
|
||||
modelbusy.release()
|
||||
|
||||
self.send_response(404)
|
||||
self.end_headers(content_type='text/html')
|
||||
|
|
@ -11380,6 +11485,7 @@ if __name__ == '__main__':
|
|||
advparser.add_argument("--analyze", metavar=('[filename]'), help="Reads the metadata, weight types and tensor names in any GGUF file.", default="")
|
||||
advparser.add_argument("--maingpu","--main-gpu","-mg", help="Only used in a multi-gpu setup. Sets the index of the main GPU that will be used.",metavar=('[Device ID]'), type=int, default=-1)
|
||||
advparser.add_argument("--batchsize","--blasbatchsize","--batch-size","-b", help="Sets the batch size used in batched processing (default 512). Setting it to -1 disables batched mode, but keeps other benefits like GPU offload.", type=int,choices=[-1,16,32,64,128,256,512,1024,2048,4096], default=512)
|
||||
advparser.add_argument("--continuous-batching","--contbatch", help=argparse.SUPPRESS, metavar=('[slots]'), type=check_range(int,0,64), default=0)
|
||||
advparser.add_argument("--blasthreads","--batchthreads","--threadsbatch","--threads-batch", help="Use a different number of threads during batching if specified. Otherwise, has the same value as --threads",metavar=('[threads]'), type=int, default=0)
|
||||
advparser.add_argument("--splitmode","-sm","--split-mode", help="How to split the model across multiple GPUs", metavar=('[split mode]'), type=str, choices=splitmode_choices, default=splitmode_choices[0])
|
||||
advparser.add_argument("--nommq", help="Disables MMQ, only used for cuda backend. This flag may be removed in future.", action='store_true')
|
||||
|
|
|
|||
|
|
@ -84,6 +84,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
generation_outputs gpttype_generate(const generation_inputs inputs);
|
||||
bool gpttype_generate_abort();
|
||||
std::string gpttype_get_chat_template();
|
||||
bool gpttype_batch_generate_enabled();
|
||||
int gpttype_batch_generate_submit(const generation_inputs inputs);
|
||||
bool gpttype_batch_generate_has_finished(int request_id);
|
||||
int gpttype_batch_generate_stream_count(int request_id);
|
||||
const char * gpttype_batch_generate_new_token(int request_id, int idx);
|
||||
const char * gpttype_batch_generate_pending_output(int request_id);
|
||||
generation_outputs gpttype_batch_generate_result(int request_id);
|
||||
bool gpttype_batch_generate_abort(int request_id);
|
||||
void gpttype_batch_generate_release(int request_id);
|
||||
|
||||
const std::string & gpttype_get_pending_output();
|
||||
std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue