mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-06-01 06:00:36 +00:00
Hexagon: OP_GATED_DELTA_NET K>1 support (#23531)
* K>1 state snapshot support * removed picky indent multiple of 4 fixes
This commit is contained in:
parent
8ad8aef447
commit
939a7dd648
2 changed files with 29 additions and 8 deletions
|
|
@ -2537,6 +2537,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses
|
|||
const int64_t H = v->ne[1];
|
||||
const int64_t n_tokens = v->ne[2];
|
||||
const int64_t n_seqs = v->ne[3];
|
||||
const int64_t K = state->ne[1];
|
||||
|
||||
if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) {
|
||||
return false;
|
||||
|
|
@ -2549,10 +2550,10 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses
|
|||
if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) {
|
||||
return false;
|
||||
}
|
||||
if (ggml_nelements(state) != S_v * S_v * H * n_seqs) {
|
||||
if (ggml_nelements(state) != S_v * S_v * H * n_seqs * K) {
|
||||
return false;
|
||||
}
|
||||
if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) {
|
||||
if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -586,6 +586,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
|
|||
const uint32_t H = v->ne[1];
|
||||
const uint32_t n_tokens = v->ne[2];
|
||||
const uint32_t n_seqs = v->ne[3];
|
||||
const uint32_t K = state->ne[1];
|
||||
|
||||
const uint32_t total_rows = H * n_seqs;
|
||||
if (ith >= total_rows) {
|
||||
|
|
@ -606,6 +607,10 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
|
|||
float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
|
||||
float local_sums[4] __attribute__((aligned(128)));
|
||||
|
||||
const uint64_t state_seq_stride = state->nb[2] / sizeof(float);
|
||||
const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs;
|
||||
const int64_t shift = (int64_t) n_tokens - (int64_t) K;
|
||||
|
||||
for (uint32_t ir = ith; ir < total_rows; ir += nth) {
|
||||
const uint32_t iv1 = ir % H;
|
||||
const uint32_t iv3 = ir / H;
|
||||
|
|
@ -615,8 +620,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
|
|||
const uint32_t iq3 = iv3 / rq3;
|
||||
const uint32_t ik3 = iv3 / rk3;
|
||||
|
||||
float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v;
|
||||
|
||||
memcpy(s_out, s_in, gctx->state_bytes);
|
||||
float * s_work = s_out;
|
||||
|
|
@ -689,6 +694,16 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
|
|||
}
|
||||
}
|
||||
|
||||
if (K > 1) {
|
||||
const int64_t target_slot = (int64_t) t - shift;
|
||||
if (target_slot >= 0 && target_slot < (int64_t) K) {
|
||||
float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
if (curr_state_o != s_work) {
|
||||
memcpy(curr_state_o, s_work, gctx->state_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
attn_data += (uint64_t) S_v * H;
|
||||
}
|
||||
}
|
||||
|
|
@ -709,6 +724,7 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
|
|||
const uint32_t S_v = v->ne[0];
|
||||
const uint32_t H = v->ne[1];
|
||||
const uint32_t n_seqs = v->ne[3];
|
||||
const uint32_t K = state->ne[1];
|
||||
|
||||
const uint32_t total_rows = H * n_seqs;
|
||||
if (ith >= total_rows) {
|
||||
|
|
@ -736,6 +752,9 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
|
|||
spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith;
|
||||
}
|
||||
|
||||
const uint64_t state_seq_stride = state->nb[2] / sizeof(float);
|
||||
const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs;
|
||||
|
||||
for (uint32_t ir = ith; ir < total_rows; ir += nth) {
|
||||
const uint32_t iv1 = ir % H;
|
||||
const uint32_t iv3 = ir / H;
|
||||
|
|
@ -745,8 +764,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
|
|||
const uint32_t iq3 = iv3 / rq3;
|
||||
const uint32_t ik3 = iv3 / rk3;
|
||||
|
||||
float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v;
|
||||
float * s_work;
|
||||
|
||||
if (spad) {
|
||||
|
|
@ -901,6 +920,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) {
|
|||
const uint32_t H = v->ne[1];
|
||||
const uint32_t n_tokens = v->ne[2];
|
||||
const uint32_t n_seqs = v->ne[3];
|
||||
const uint32_t K = state->ne[1];
|
||||
|
||||
if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
|
|
@ -913,10 +933,10 @@ int op_gated_delta_net(struct htp_ops_context * octx) {
|
|||
(n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) {
|
||||
if (state->ne[0] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) {
|
||||
if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue