memory : handle kv_unified for hybrid models (#15050)

Author: compilade
Date:   2025-08-03 15:43:07 -04:00 (committed by GitHub)
parent 97366dc6ab
commit 11a3811164
3 changed files with 4 additions and 1 deletion

src/llama-memory-hybrid.cpp

@@ -25,6 +25,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         /* common */
         uint32_t n_seq_max,
         bool offload,
+        bool unified,
         /* layer filters */
         layer_filter_cb && filter_attn,
         layer_filter_cb && filter_recr) :
@@ -38,7 +39,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         type_v,
         v_trans,
         offload,
-        1,
+        unified,
         kv_size,
         n_seq_max,
         n_pad,
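
Previously the hybrid memory passed a hard-coded 1 here, so the attention-side KV cache of a hybrid (attention + recurrent) model always ran in unified, single-stream mode regardless of the context configuration. Below is a minimal sketch, not the actual llama.cpp implementation, of how such a unified flag typically selects between one shared KV stream and one stream per sequence; kv_cache_sketch and its members are illustrative names only.

#include <cstdint>
#include <cstdio>
#include <vector>

// Toy stand-in for the attention KV cache: with unified == true all sequences
// share a single stream of cells, otherwise each sequence gets its own stream.
struct kv_cache_sketch {
    uint32_t n_stream;
    std::vector<std::vector<int32_t>> cells; // one cell pool per stream

    kv_cache_sketch(bool unified, uint32_t kv_size, uint32_t n_seq_max)
        : n_stream(unified ? 1 : n_seq_max),
          cells(n_stream, std::vector<int32_t>(kv_size, -1)) {}
};

int main() {
    // Before this commit, hybrid models behaved as if unified were always true:
    kv_cache_sketch shared(/*unified=*/true,  /*kv_size=*/4096, /*n_seq_max=*/4);
    kv_cache_sketch split (/*unified=*/false, /*kv_size=*/4096, /*n_seq_max=*/4);
    std::printf("unified: %u stream(s)\n", shared.n_stream); // 1
    std::printf("split:   %u stream(s)\n", split.n_stream);  // 4
    return 0;
}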

src/llama-memory-hybrid.h

@@ -39,6 +39,7 @@ public:
         /* common */
         uint32_t n_seq_max,
         bool offload,
+        bool unified,
         /* layer filters */
         layer_filter_cb && filter_attn = nullptr,
         layer_filter_cb && filter_recr = nullptr);
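
Note that unified is inserted in front of the two defaulted layer-filter parameters, so existing call sites that relied on the defaults stop compiling until they pass the new flag explicitly, which is why the call in llama_model::create_memory below changes in the same commit. A self-contained toy illustration of that effect (ctor_before and ctor_after are hypothetical stand-ins, not llama.cpp functions):

#include <cstdint>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t)>;

// Stand-ins for the constructor signature before and after this commit,
// with the unrelated parameters elided:
inline void ctor_before(bool /*offload*/, layer_filter_cb /*filter*/ = nullptr) {}
inline void ctor_after (bool /*offload*/, bool /*unified*/, layer_filter_cb /*filter*/ = nullptr) {}

int main() {
    ctor_before(true);        // OK: filter defaults to nullptr
    // ctor_after(true);      // would not compile: 'unified' has no default
    ctor_after(true, false);  // call sites must pass the new flag explicitly
    return 0;
}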

src/llama-model.cpp

@@ -17598,6 +17598,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
                 /* n_seq_max */         cparams.n_seq_max,
                 /* offload */           cparams.offload_kqv,
+                /* unified */           cparams.kv_unified,
                 /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
                 /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
         } else {
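
For completeness, cparams.kv_unified is populated from the public context parameters when a context is created. A hedged usage sketch, assuming the corresponding llama_context_params field is likewise named kv_unified and using a placeholder model path:

#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max  = 4;    // several sequences ...
    cparams.kv_unified = true; // ... sharing one KV buffer; with this commit the
                               // flag now reaches the attention cache of hybrid
                               // models instead of being silently forced on

    llama_context * ctx = llama_init_from_model(model, cparams);

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}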