not working correctly

Concedo 2025-06-02 22:12:10 +08:00
commit 6ce85c54d6
30 changed files with 1603 additions and 724 deletions


@@ -8992,9 +8992,9 @@ struct llm_build_mamba : public llm_graph_context {
             ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                      int   il) const {
-        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-        const auto kv_head = kv_self->head;
+        const auto kv_head = kv_state->get_head();
 
         const int64_t d_conv  = hparams.ssm_d_conv;
         const int64_t d_inner = hparams.ssm_d_inner;
@@ -9012,8 +9012,8 @@ struct llm_build_mamba : public llm_graph_context {
         GGML_ASSERT(ubatch.equal_seqs);
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
-        ggml_tensor * conv_states_all = kv_self->k_l[il];
-        ggml_tensor * ssm_states_all  = kv_self->v_l[il];
+        ggml_tensor * conv_states_all = kv_state->get_k_l(il);
+        ggml_tensor * ssm_states_all  = kv_state->get_v_l(il);
 
         // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_copy_mask_state(
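Note: the hunks above all follow one pattern — instead of casting the `memory` pointer to the concrete `llama_kv_cache_recurrent` and reading its fields (`head`, `k_l[il]`, `v_l[il]`) directly, the graph builders now cast `mstate` to `llama_kv_cache_recurrent_state` and go through accessors. A minimal sketch of the accessor surface these calls imply; the member names below are hypothetical and only mirror the calls visible in this diff, not the real class declaration in llama.cpp:

#include <cstdint>
#include <vector>

struct ggml_tensor; // opaque here; defined by ggml

// Hypothetical read-only state interface implied by this diff (a sketch, not
// the actual llama_kv_cache_recurrent_state).
class recurrent_state_sketch {
public:
    // cache slot where the states for the current ubatch start
    uint32_t get_head() const { return head; }

    // per-layer conv state (kept in the cache's k slot) and ssm/wkv state (v slot)
    ggml_tensor * get_k_l(int il) const { return k_l[il]; }
    ggml_tensor * get_v_l(int il) const { return v_l[il]; }

private:
    uint32_t head = 0;
    std::vector<ggml_tensor *> k_l;
    std::vector<ggml_tensor *> v_l;
};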
@@ -11740,7 +11740,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                      int   il) const {
-        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -11750,7 +11750,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         const auto n_head = n_embd / head_size;
         const auto n_head_kv = hparams.n_head_kv(il);
 
-        const auto kv_head = kv_self->head;
+        const auto kv_head = kv_state->get_head();
 
         const auto & layer = model.layers[il];
 
@@ -11862,7 +11862,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         }
 
         ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_self->v_l[il], state_copy, state_mask,
+                gf, kv_state->get_v_l(il), state_copy, state_mask,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output;
@@ -11881,9 +11881,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_self->v_l[il],
+                        kv_state->get_v_l(il),
                         hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
                     )
                 )
             );
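Note: in the `ggml_view_1d` change just above, only the tensor argument is swapped; the view geometry is unchanged. It spans `n_embd_v_s()` elements for each of the `n_seqs` sequences in the ubatch, starting at the byte offset of cache slot `kv_head`. Restated as standalone helpers (a sketch of the arithmetic, not code from the commit):

#include <cstddef>
#include <cstdint>

// Number of elements the written-back view spans: one state per sequence.
static int64_t wkv_view_n_elems(int64_t n_embd_v_s, int64_t n_seqs) {
    return n_embd_v_s * n_seqs;
}

// Byte offset of the first cache slot written for this ubatch.
static size_t wkv_view_offset_bytes(int64_t n_embd_v_s, int64_t kv_head, size_t elem_size) {
    return (size_t)(n_embd_v_s * kv_head) * elem_size;
}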
@@ -12136,7 +12136,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
            ggml_tensor *& first_layer_value,
       const llama_ubatch & ubatch,
                      int   il) const {
-        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12145,7 +12145,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         const auto head_count = n_embd / head_size;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
 
-        const auto kv_head = kv_self->head;
+        const auto kv_head = kv_state->get_head();
 
         const auto & layer = model.layers[il];
 
@@ -12216,7 +12216,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
         ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_self->v_l[il], state_copy, state_mask,
+                gf, kv_state->get_v_l(il), state_copy, state_mask,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12230,9 +12230,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_self->v_l[il],
+                        kv_state->get_v_l(il),
                         hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
                     )
                 )
             );
@@ -13330,7 +13330,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     params.swa_full,
                     cparams.n_ctx,
                     cparams.n_seq_max,
-                    cparams.n_batch,
+                    cparams.n_ubatch,
                     padding);
         } else {
             GGML_ASSERT(!hparams.is_swa_any());
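Note: the change in `create_memory` sizes the SWA (sliding-window attention) KV cache from `cparams.n_ubatch` instead of `cparams.n_batch`. My reading of the diff, not the author's stated rationale: a sliding-window cache only needs to hold the window itself plus the tokens of one micro-batch in flight, so the micro-batch size is the relevant bound. A rough sizing sketch under that assumption; the real computation lives inside llama.cpp's cache implementation and may differ:

#include <cstdint>

// Rough per-model SWA cache sizing under the assumption described above.
static uint32_t swa_cache_size_sketch(uint32_t n_swa, uint32_t n_seq_max,
                                      uint32_t n_ubatch, uint32_t padding) {
    const uint32_t needed = n_swa * n_seq_max + n_ubatch;
    // round up to the padding the backends require
    return ((needed + padding - 1) / padding) * padding;
}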
@@ -13693,6 +13693,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
     return model->hparams.n_head_kv();
 }
 
+int32_t llama_model_n_swa(const llama_model * model) {
+    return model->hparams.n_swa;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
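Note: the last hunk adds `llama_model_n_swa`, exposing the model's sliding-window size through the public C API in the same style as `llama_model_n_head_kv` just above it. A minimal usage sketch, assuming a model already loaded through the usual llama.h entry points:

#include <cstdio>

#include "llama.h"

// Report whether the loaded model uses sliding-window attention and how wide
// the window is (0 means no SWA).
static void print_swa_info(const llama_model * model) {
    const int32_t n_swa = llama_model_n_swa(model);
    if (n_swa > 0) {
        printf("sliding-window attention enabled, window = %d tokens\n", n_swa);
    } else {
        printf("no sliding-window attention\n");
    }
}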