SYCL : gated_delta_net K>1 (#23174)

* sycl_gated_delta_net K>1

* editor_config
This commit is contained in:
karavayev 2026-05-22 08:48:56 -04:00 committed by GitHub
parent 8cc67efcd4
commit 56f16f235c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,7 +6,7 @@
#include <cmath>
template <int S_v, bool KDA>
template <int S_v, bool KDA, bool keep_rs_t>
void gated_delta_net_sycl(const float * q,
const float * k,
const float * v,
@ -28,7 +28,8 @@ void gated_delta_net_sycl(const float * q,
int64_t sb3,
const sycl::uint3 neqk1_magic,
const sycl::uint3 rq3_magic,
float scale) {
float scale,
int K) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const uint32_t h_idx = item_ct1.get_group(2);
const uint32_t sequence = item_ct1.get_group(1);
@ -43,9 +44,13 @@ void gated_delta_net_sycl(const float * q,
float * attn_data = dst;
float * state = dst + attn_score_elems;
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_offset;
curr_state += state_offset;
// input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
state += state_out_offset;
curr_state += state_in_offset + col * S_v;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v;
@ -55,9 +60,13 @@ void gated_delta_net_sycl(const float * q,
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[col * S_v + i];
s_shard[r] = curr_state[i];
}
// slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
// are written; earlier slots are left untouched (caller-owned).
const int shift = (int) n_tokens - K;
for (int t = 0; t < n_tokens; t++) {
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
@ -131,17 +140,32 @@ void gated_delta_net_sycl(const float * q,
}
attn_data += S_v * H;
}
// Write state back to global memory
if constexpr (keep_rs_t) {
const int target_slot = t - shift;
if (target_slot >= 0 && target_slot < K) {
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[col * S_v + i] = s_shard[r];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
curr_state[col * S_v + i] = s_shard[r];
}
}
}
}
if constexpr (!keep_rs_t) {
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[col * S_v + i] = s_shard[r];
}
}
}
template <bool KDA>
template <bool KDA, bool keep_rs_t>
static void launch_gated_delta_net(const float * q_d,
const float * k_d,
const float * v_d,
@ -165,6 +189,7 @@ static void launch_gated_delta_net(const float * q_d,
int64_t neqk1,
int64_t rq3,
float scale,
int K,
dpct::queue_ptr stream) {
//TODO: Add chunked kernel for even faster pre-fill
const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size;
@ -182,9 +207,9 @@ static void launch_gated_delta_net(const float * q_d,
constexpr int sv = 16;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
gated_delta_net_sycl<sv, KDA, keep_rs_t>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
sb3, neqk1_magic, rq3_magic, scale);
sb3, neqk1_magic, rq3_magic, scale, K);
});
}
break;
@ -193,9 +218,9 @@ static void launch_gated_delta_net(const float * q_d,
constexpr int sv = 32;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
gated_delta_net_sycl<sv, KDA, keep_rs_t>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
sb3, neqk1_magic, rq3_magic, scale);
sb3, neqk1_magic, rq3_magic, scale, K);
});
}
break;
@ -204,9 +229,9 @@ static void launch_gated_delta_net(const float * q_d,
constexpr int sv = 64;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(
gated_delta_net_sycl<sv, KDA, keep_rs_t>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
});
}
break;
@ -216,9 +241,9 @@ static void launch_gated_delta_net(const float * q_d,
constexpr int sv = 128;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(
gated_delta_net_sycl<sv, KDA, keep_rs_t>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
});
}
break;
@ -290,14 +315,30 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor *
dpct::queue_ptr stream = ctx.stream();
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const int K = (int) src_state->ne[1];
const bool keep_rs = K > 1;
if (kda) {
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, stream);
if (keep_rs) {
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
} else {
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
}
} else {
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, stream);
if (keep_rs) {
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
} else {
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
}
}
}