diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 5e8a4a740..3af7aff70 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2537,6 +2537,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses const int64_t H = v->ne[1]; const int64_t n_tokens = v->ne[2]; const int64_t n_seqs = v->ne[3]; + const int64_t K = state->ne[1]; if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { return false; @@ -2549,10 +2550,10 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { return false; } - if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { + if (ggml_nelements(state) != S_v * S_v * H * n_seqs * K) { return false; } - if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index 2e84badc9..c4d08bb21 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -586,6 +586,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = state->ne[1]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -606,6 +607,10 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_sums[4] __attribute__((aligned(128))); + const uint64_t state_seq_stride = state->nb[2] / sizeof(float); + const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + const int64_t shift = (int64_t) n_tokens - (int64_t) K; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { const uint32_t iv1 = ir % H; const uint32_t iv3 = ir / H; @@ -615,8 +620,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = iv3 / rq3; const uint32_t ik3 = iv3 / rk3; - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; memcpy(s_out, s_in, gctx->state_bytes); float * s_work = s_out; @@ -689,6 +694,16 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo } } + if (K > 1) { + const int64_t target_slot = (int64_t) t - shift; + if (target_slot >= 0 && target_slot < (int64_t) K) { + float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + if (curr_state_o != s_work) { + memcpy(curr_state_o, s_work, gctx->state_bytes); + } + } + } + attn_data += (uint64_t) S_v * H; } } @@ -709,6 +724,7 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t S_v = v->ne[0]; const uint32_t H = v->ne[1]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = state->ne[1]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -736,6 +752,9 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; } + const uint64_t state_seq_stride = state->nb[2] / sizeof(float); + const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { const uint32_t iv1 = ir % H; const uint32_t iv3 = ir / H; @@ -745,8 +764,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = iv3 / rq3; const uint32_t ik3 = iv3 / rk3; - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; float * s_work; if (spad) { @@ -901,6 +920,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) { const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = state->ne[1]; if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { return HTP_STATUS_NO_SUPPORT; @@ -913,10 +933,10 @@ int op_gated_delta_net(struct htp_ops_context * octx) { (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { return HTP_STATUS_NO_SUPPORT; } - if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { + if (state->ne[0] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { return HTP_STATUS_NO_SUPPORT; } - if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { return HTP_STATUS_NO_SUPPORT; }