graph : fix KQ mask, lora, cvec reuse checks (#19644)

* graph : fix KQ mask reuse condition

* cont : dedup KQ mask build and can_reuse

* cont : fix build

* graph : fix adapter check for reuse
Georgi Gerganov 2026-02-16 09:21:11 +02:00 committed by GitHub
parent 267ba5a1d9
commit d5dfc33027
4 changed files with 67 additions and 54 deletions
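
The adapter part of the fix hinges on ownership: cvec and loras are moved behind std::unique_ptr, so the raw pointer stored in the graph params is replaced whenever the adapter set is replaced, and a plain pointer comparison in the reuse check can tell that a previously built graph is stale. The sketch below illustrates only that mechanism, under that assumed reading of the change; every name in it (adapter_set, graph_params, context, can_reuse) is an illustrative stand-in, not the llama.cpp API.

// Minimal sketch (assumed reading of the change, not the llama.cpp API):
// holding the adapter set behind a unique_ptr means the raw pointer captured
// in the graph params changes whenever the set is replaced, so a pointer
// comparison is enough to detect that a built graph is stale.
#include <cstdio>
#include <map>
#include <memory>
#include <utility>

struct adapter_set  { std::map<int, float> scales; };          // stand-in for the lora adapter map
struct graph_params { const adapter_set * loras = nullptr; };   // stand-in for the graph params

struct context {
    std::unique_ptr<adapter_set> loras = std::make_unique<adapter_set>();

    graph_params params() const { return { loras.get() }; }

    void set_adapters(std::map<int, float> scales) {
        loras.reset(new adapter_set{ std::move(scales) });      // fresh object => new address
    }
};

// A previously built graph is reusable only if it still points at the same adapter set.
static bool can_reuse(const graph_params & built, const graph_params & now) {
    return built.loras == now.loras;
}

int main() {
    context ctx;
    graph_params built = ctx.params();   // captured when the graph was built

    std::printf("reuse before update: %d\n", can_reuse(built, ctx.params()));  // 1
    ctx.set_adapters({{0, 0.5f}});       // replace the adapter set
    std::printf("reuse after update:  %d\n", can_reuse(built, ctx.params()));  // 0
    return 0;
}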


@@ -22,6 +22,8 @@ llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
     model(model),
+    cvec(std::make_unique<llama_adapter_cvec>()),
+    loras(std::make_unique<llama_adapter_loras>()),
     balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
     // TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
     //      may need to be backend-dependent
@@ -1065,11 +1067,11 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a
         return;
     }
-    loras.clear();
+    loras.reset(new llama_adapter_loras());
     for (size_t i = 0; i < n_adapters; i ++) {
         if (scales[i] != 0.0f) {
-            loras[adapters[i]] = scales[i];
+            loras->insert({adapters[i], scales[i]});
         }
     }
@@ -1079,14 +1081,14 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a
 bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
     LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
-    if (n_adapters != loras.size()) {
+    if (n_adapters != loras->size()) {
         return false;
     }
     for (size_t i = 0; i < n_adapters; i ++) {
-        auto it = loras.find(adapters[i]);
+        auto it = loras->find(adapters[i]);
-        if (it == loras.end() || it->second != scales[i]) {
+        if (it == loras->end() || it->second != scales[i]) {
             return false;
         }
     }
@@ -1104,7 +1106,7 @@ bool llama_context::set_adapter_cvec(
     // TODO: should we reserve?
-    return cvec.apply(model, data, len, n_embd, il_start, il_end);
+    return cvec->apply(model, data, len, n_embd, il_start, il_end);
 }
 llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
@@ -2081,8 +2083,8 @@ llm_graph_params llama_context::graph_params(
         /*.gtype =*/ gtype,
         /*.sched =*/ sched.get(),
         /*.backend_cpu =*/ backend_cpu,
-        /*.cvec =*/ &cvec,
-        /*.loras =*/ &loras,
+        /*.cvec =*/ cvec.get(),
+        /*.loras =*/ loras.get(),
         /*.mctx =*/ mctx,
         /*.cross =*/ &cross,
         /*.samplers =*/ sampling.samplers,
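
For contrast, a sketch of why the previous layout could not support this check: with the adapters stored by value, the address handed to the graph params stays the same across updates, so a pointer comparison cannot see that the contents were replaced after the graph was built. Again, all names are illustrative stand-ins under an assumed reading of the old layout, not the actual types.

// Counterpoint sketch (illustrative names): a by-value member has a fixed
// address, so the pointer captured at graph-build time stays "equal" even
// after the adapter contents change.
#include <cstdio>
#include <map>
#include <utility>

struct adapter_set { std::map<int, float> scales; };

struct context_by_value {
    adapter_set loras;                                   // plain member, fixed address

    const adapter_set * params() const { return &loras; }

    void set_adapters(std::map<int, float> scales) {
        loras.scales = std::move(scales);                // contents change, address does not
    }
};

int main() {
    context_by_value ctx;
    const adapter_set * built_with = ctx.params();       // captured at graph-build time

    ctx.set_adapters({{0, 0.5f}});
    std::printf("pointer still equal: %d\n", built_with == ctx.params());  // prints 1
    return 0;
}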