fixed SWA prompt processing bug by retrying with smaller batches

Concedo 2025-07-21 23:34:22 +08:00
parent 6d50def409
commit 9f4d0f6ccf
4 changed files with 39 additions and 10 deletions
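
In outline, the change works like this: attempt the full-batch decode first; if llama_decode returns 1 (no contiguous KV cache slot large enough for the whole batch could be found), split the pending tokens into chunks of 128 and decode them sequentially, advancing the past-token position after each chunk. Below is a minimal self-contained sketch of that retry pattern; try_decode is a hypothetical stand-in for llama_decode, and all names and numbers are illustrative, not the actual KoboldCpp code (which appears in the diff that follows).

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama_decode: returns 0 on success, or 1 when
// no contiguous KV cache slot large enough for the whole batch is found.
static int try_decode(const std::vector<int>& tokens, int /*n_past*/) {
    return tokens.size() > 128 ? 1 : 0; // pretend oversized batches fail
}

// On a KV-slot failure, fall back to decoding in chunks of chunk_size,
// advancing the past-token counter as each chunk is accepted.
static bool decode_with_retry(const std::vector<int>& tokens, int n_past, size_t chunk_size) {
    if (try_decode(tokens, n_past) == 0) {
        return true; // the full batch fit on the first attempt
    }
    bool ok = true;
    int past = n_past;
    for (size_t i = 0; i < tokens.size(); i += chunk_size) {
        size_t end = std::min(i + chunk_size, tokens.size());
        std::vector<int> chunk(tokens.begin() + i, tokens.begin() + end);
        ok = ok && (try_decode(chunk, past) == 0);
        past += (int)chunk.size();
    }
    return ok;
}

int main() {
    std::vector<int> prompt(1000, 7); // a long prompt that won't fit at once
    printf("decode %s\n", decode_with_retry(prompt, 0, 128) ? "ok" : "failed");
    return 0;
}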


@@ -2174,14 +2174,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     llama_model_params model_params = llama_model_default_params();
     llama_context_params llama_ctx_params = llama_context_default_params();
     llama_ctx_params.n_ctx = clamped_max_context_length;
-    if(kcpp_data->use_contextshift)
-    {
-        llama_ctx_params.n_ctx += extra_context_handle_fragmentation;
-    }
-    else
-    {
-        llama_ctx_params.n_ctx += (extra_context_handle_fragmentation/2);
-    }
+    llama_ctx_params.n_ctx += extra_context_handle_fragmentation;
     llama_ctx_params.offload_kqv = !inputs.low_vram;
     llama_ctx_params.kv_unified = true;
@@ -3844,7 +3837,31 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     {
        draft_used = false;
        kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past, use_mrope, false);
-       evalres = (llama_decode(llama_ctx_v4, batch.batch)==0);
+       int32_t decode_status = llama_decode(llama_ctx_v4, batch.batch);
+       if(decode_status==1 && embd.size()>128)
+       {
+           printf("Couldn't find a big enough KV slot. Retrying with a smaller batch size of 128...\n");
+           std::vector<std::vector<gpt_vocab::id>> parts = split_big_vector(embd,128);
+           int temp_past = n_past;
+           evalres = true;
+           for(size_t p=0;p<parts.size();++p)
+           {
+               std::vector<gpt_vocab::id> chunk = parts[p];
+               kcpp_embd_batch smallbatch = kcpp_embd_batch(chunk, temp_past, use_mrope, false);
+               int32_t decode_status2 = llama_decode(llama_ctx_v4, smallbatch.batch);
+               if(debugmode==1 && !is_quiet)
+               {
+                   printf("Retry chunk: %d at %d... status: %s\n",(int)chunk.size(),temp_past,(decode_status2==0?"ok":"fail"));
+               }
+               evalres = (evalres && (decode_status2==0));
+               temp_past += chunk.size();
+           }
+       }
+       else
+       {
+           evalres = (decode_status==0);
+       }
        if(draft_ctx)
        {
            evalres = (evalres && (llama_decode(draft_ctx, batch.batch)==0));
@@ -3928,6 +3945,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     if (!evalres)
     {
         fprintf(stderr, "\nFailed to predict at token position %d! Check your context buffer sizes!\n",n_past);
+        media_composite_image_signature = ""; //force invalidate
         output.text = nullptr;
         output.status = 0;
         output.prompt_tokens = output.completion_tokens = 0;


@@ -356,6 +356,16 @@ std::string get_timestamp_str()
     return timestamp;
 }
+
+//split a big vector into multiple smaller vectors of chunk_size or less
+std::vector<std::vector<int>> split_big_vector(const std::vector<int>& big_arr, size_t chunk_size) {
+    std::vector<std::vector<int>> small_arrs;
+    for (size_t i = 0; i < big_arr.size(); i += chunk_size) {
+        size_t end = std::min(i + chunk_size, big_arr.size());
+        small_arrs.emplace_back(big_arr.begin() + i, big_arr.begin() + end);
+    }
+    return small_arrs;
+}
 std::vector<float> resample_wav(const std::vector<float>& input, uint32_t input_rate, uint32_t output_rate) {
     size_t input_size = input.size();

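The helper's behavior is straightforward: every chunk holds at most chunk_size elements, and the last chunk takes whatever remains. A quick standalone usage sketch follows; the helper body is copied from the diff above, while main and its numbers are illustrative only.

#include <algorithm>
#include <cstdio>
#include <vector>

// Copied from the diff above.
std::vector<std::vector<int>> split_big_vector(const std::vector<int>& big_arr, size_t chunk_size) {
    std::vector<std::vector<int>> small_arrs;
    for (size_t i = 0; i < big_arr.size(); i += chunk_size) {
        size_t end = std::min(i + chunk_size, big_arr.size());
        small_arrs.emplace_back(big_arr.begin() + i, big_arr.begin() + end);
    }
    return small_arrs;
}

int main() {
    std::vector<int> tokens(300); // e.g. 300 prompt tokens
    auto parts = split_big_vector(tokens, 128);
    printf("%zu chunks:", parts.size()); // prints "3 chunks: 128 128 44"
    for (const auto& p : parts) printf(" %zu", p.size());
    printf("\n");
    return 0;
}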

@@ -61,6 +61,7 @@ std::string kcpp_base64_encode(const unsigned char* data, unsigned int data_length
 std::string kcpp_base64_encode(const std::string &data);
 std::string get_timestamp_str();
+std::vector<std::vector<int>> split_big_vector(const std::vector<int>& big_arr, size_t chunk_size);
 std::vector<float> resample_wav(const std::vector<float>& input, uint32_t input_rate, uint32_t output_rate);
 int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector<int32_t> & last_n_tokens, float rep_pen, float top_p, int top_k, float temp, std::mt19937 & rng);


@@ -31,7 +31,7 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
     uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
     //kcpp: pad the swa kv cache as well, similar to extra_context_handle_fragmentation
-    size_swa += 32;
+    size_swa += 128;
     size_swa = GGML_PAD(size_swa, n_pad);
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
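
Note the order of operations here: the extra 128 cells of slack are added first, and only then is the total rounded up to a multiple of n_pad by GGML_PAD. A small sketch of that arithmetic, using a local round_up helper with the same round-up-to-multiple effect as GGML_PAD (ggml's macro uses a bitmask form that assumes a power-of-two pad) and hypothetical sizes:

#include <cstdint>
#include <cstdio>

// Round x up to the next multiple of n; same effect as GGML_PAD
// for the power-of-two pad sizes ggml uses.
static uint32_t round_up(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    uint32_t size_swa = 4000; // hypothetical SWA cache size in cells
    uint32_t n_pad = 256;     // hypothetical padding granularity
    size_swa += 128;                      // the slack added by this commit
    size_swa = round_up(size_swa, n_pad); // 4128 rounds up to 4352
    printf("padded swa cache size: %u\n", size_swa);
    return 0;
}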