added support for added memory in batching mode, plus gemma and glm prompt fixes

Concedo 2026-05-10 23:39:03 +08:00
parent 33ca75d56f
commit bfaddd7a3b
2 changed files with 23 additions and 5 deletions
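In plain terms, and hedged as a reading of this diff rather than the full implementation: each batched request now carries its memory text alongside the prompt, the memory is tokenized into its own buffer, and the two are merged before the usual context-length trim. A minimal C++ sketch of that flow follows; the struct, tokenizer and claim function here are simplified stand-ins, not the project's code (the real path uses BatchGenerateRequest, TokenizeString and AppendDedicatedMemoryAndNegativePrompt, shown in the hunks below).

// Minimal sketch of the data flow this commit adds (illustrative names only).
#include <algorithm>
#include <string>
#include <vector>

struct SketchRequest {
    std::string prompt;
    std::string prompt_added_memory;   // new: memory text carried per batched request
    std::vector<int> prompt_tokens;    // merged token stream used for prefill
};

// Stand-in tokenizer: one placeholder token per character, only to show ordering.
static std::vector<int> sketch_tokenize(const std::string &text) {
    return std::vector<int>(text.size(), 0);
}

static void sketch_claim(SketchRequest &req, int n_ctx, int max_length) {
    // Prompt-format adjustments (the gemma/glm fixes) would run here, before tokenizing.
    req.prompt_tokens = sketch_tokenize(req.prompt);
    std::vector<int> memory_tokens;
    if (!req.prompt_added_memory.empty()) {
        memory_tokens = sketch_tokenize(req.prompt_added_memory);
    }
    // Assumption: memory tokens are injected ahead of the prompt; in the real code
    // the merge is handled by AppendDedicatedMemoryAndNegativePrompt.
    req.prompt_tokens.insert(req.prompt_tokens.begin(),
                             memory_tokens.begin(), memory_tokens.end());
    // Trim from the front so prompt plus requested generation still fits the context.
    if (max_length > 0 && (int)req.prompt_tokens.size() + max_length > n_ctx) {
        int keep = std::max(1, n_ctx - max_length);
        if ((int)req.prompt_tokens.size() > keep) {
            req.prompt_tokens.erase(req.prompt_tokens.begin(),
                                    req.prompt_tokens.end() - keep);
        }
    }
}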


@@ -3411,6 +3411,7 @@ struct BatchGenerateRequest
int slot = -1;
BatchState state = BatchState::WAITING;
std::string prompt;
std::string prompt_added_memory;
std::vector<std::string> stop_sequences;
int max_context_length = 0;
int max_length = 0;
@@ -3557,10 +3558,6 @@ static bool batch_inputs_eligible(const generation_inputs & inputs)
{
return false;
}
if(inputs.memory && std::string(inputs.memory).size() > 0)
{
return false;
}
if(inputs.negative_prompt && std::string(inputs.negative_prompt).size() > 0)
{
return false;
@@ -3831,12 +3828,23 @@ static bool batch_claim_waiting_locked()
req->slot = slot;
req->state = BatchState::PREFILL;
batch_touched_since_legacy = true;
ApplyPromptFormatAdjustments(req->prompt_added_memory, req->prompt);
std::vector<llama_token> added_memory_tokens; //temporary buf before copying over
TokenizeString(req->prompt, req->prompt_tokens, file_format, add_bos_token);
if(req->prompt_tokens.empty())
{
TokenizeString("", req->prompt_tokens, file_format, add_bos_token);
}
if(req->prompt_added_memory!="")
{
TokenizeString(req->prompt_added_memory, added_memory_tokens, file_format, add_bos_token);
}
int n_ctx = req->max_context_length > 0 ? std::min(req->max_context_length, kcpp_data->n_ctx) : kcpp_data->n_ctx;
AppendDedicatedMemoryAndNegativePrompt(req->prompt_tokens, added_memory_tokens, std::vector<llama_token>(), req->max_length, n_ctx);
if(req->max_length > 0 && (int) req->prompt_tokens.size() + req->max_length > n_ctx)
{
int keep = std::max(1, n_ctx - req->max_length);
@@ -3845,6 +3853,15 @@ static bool batch_claim_waiting_locked()
req->prompt_tokens.erase(req->prompt_tokens.begin(), req->prompt_tokens.end() - keep);
}
}
if (debugmode==1 && !is_quiet)
{
std::string outstr = "";
printf("\n\n[Debug: Dump %zu Raw Input Tokens]\n",req->prompt_tokens.size());
outstr += get_tok_vec_str(req->prompt_tokens);
printf("%s\n", RemoveBell(outstr).c_str());
}
req->prompt_token_count = req->prompt_tokens.size();
req->sampler = batch_build_sampler(*req);
for(llama_token token : req->prompt_tokens)
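For concreteness, a worked example of the trim in the hunks above (the numbers are illustrative, not taken from the source): with n_ctx = 4096 and max_length = 512, keep = std::max(1, 4096 - 512) = 3584. A merged prompt-plus-memory stream of 4000 tokens therefore has its first 416 tokens erased, keeping the most recent 3584 and leaving exactly max_length = 512 slots for generation within the context window.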
@@ -4024,6 +4041,7 @@ int gpttype_batch_generate_submit(const generation_inputs inputs)
auto req = std::make_unique<BatchGenerateRequest>();
req->id = batch_next_request_id++;
req->prompt = inputs.prompt ? inputs.prompt : "";
req->prompt_added_memory = inputs.memory ? inputs.memory : "";
req->max_context_length = inputs.max_context_length;
req->max_length = inputs.max_length;
req->seed = inputs.seed;


@@ -2290,7 +2290,7 @@ def continuous_batching_python_eligible(genparams, api_format):
return False
if not getattr(args, "noshift", False) or getattr(args, "smartcontext", False) or getattr(args, "draftmodel", "") or getattr(args, "mmproj", "") or getattr(args, "enableguidance", False):
return False
if genparams.get("memory") or genparams.get("negative_prompt") or genparams.get("images") or genparams.get("audio"):
if genparams.get("negative_prompt") or genparams.get("images") or genparams.get("audio"):
return False
if genparams.get("ban_eos_token", False):
return False