mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-26 15:53:38 +00:00
SYCL : gated_delta_net K>1 (#23174)
* sycl_gated_delta_net K>1 * editor_config
This commit is contained in:
parent
8cc67efcd4
commit
56f16f235c
1 changed files with 66 additions and 25 deletions
|
|
@ -6,7 +6,7 @@
|
|||
#include <cmath>
|
||||
|
||||
|
||||
template <int S_v, bool KDA>
|
||||
template <int S_v, bool KDA, bool keep_rs_t>
|
||||
void gated_delta_net_sycl(const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
|
|
@ -28,7 +28,8 @@ void gated_delta_net_sycl(const float * q,
|
|||
int64_t sb3,
|
||||
const sycl::uint3 neqk1_magic,
|
||||
const sycl::uint3 rq3_magic,
|
||||
float scale) {
|
||||
float scale,
|
||||
int K) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const uint32_t h_idx = item_ct1.get_group(2);
|
||||
const uint32_t sequence = item_ct1.get_group(1);
|
||||
|
|
@ -43,9 +44,13 @@ void gated_delta_net_sycl(const float * q,
|
|||
float * attn_data = dst;
|
||||
float * state = dst + attn_score_elems;
|
||||
|
||||
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
|
||||
state += state_offset;
|
||||
curr_state += state_offset;
|
||||
// input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
|
||||
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
|
||||
const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
|
||||
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
|
||||
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
|
||||
state += state_out_offset;
|
||||
curr_state += state_in_offset + col * S_v;
|
||||
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
|
||||
|
||||
constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v;
|
||||
|
|
@ -55,9 +60,13 @@ void gated_delta_net_sycl(const float * q,
|
|||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
s_shard[r] = curr_state[col * S_v + i];
|
||||
s_shard[r] = curr_state[i];
|
||||
}
|
||||
|
||||
// slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
|
||||
// are written; earlier slots are left untouched (caller-owned).
|
||||
const int shift = (int) n_tokens - K;
|
||||
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
|
||||
const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
|
||||
|
|
@ -131,17 +140,32 @@ void gated_delta_net_sycl(const float * q,
|
|||
}
|
||||
|
||||
attn_data += S_v * H;
|
||||
}
|
||||
|
||||
|
||||
// Write state back to global memory
|
||||
if constexpr (keep_rs_t) {
|
||||
const int target_slot = t - shift;
|
||||
if (target_slot >= 0 && target_slot < K) {
|
||||
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
state[col * S_v + i] = s_shard[r];
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
curr_state[col * S_v + i] = s_shard[r];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (!keep_rs_t) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
state[col * S_v + i] = s_shard[r];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool KDA>
|
||||
template <bool KDA, bool keep_rs_t>
|
||||
static void launch_gated_delta_net(const float * q_d,
|
||||
const float * k_d,
|
||||
const float * v_d,
|
||||
|
|
@ -165,6 +189,7 @@ static void launch_gated_delta_net(const float * q_d,
|
|||
int64_t neqk1,
|
||||
int64_t rq3,
|
||||
float scale,
|
||||
int K,
|
||||
dpct::queue_ptr stream) {
|
||||
//TODO: Add chunked kernel for even faster pre-fill
|
||||
const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size;
|
||||
|
|
@ -182,9 +207,9 @@ static void launch_gated_delta_net(const float * q_d,
|
|||
constexpr int sv = 16;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
gated_delta_net_sycl<sv, KDA, keep_rs_t>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
|
||||
sb3, neqk1_magic, rq3_magic, scale);
|
||||
sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
|
@ -193,9 +218,9 @@ static void launch_gated_delta_net(const float * q_d,
|
|||
constexpr int sv = 32;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
gated_delta_net_sycl<sv, KDA, keep_rs_t>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
|
||||
sb3, neqk1_magic, rq3_magic, scale);
|
||||
sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
|
@ -204,9 +229,9 @@ static void launch_gated_delta_net(const float * q_d,
|
|||
constexpr int sv = 64;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(
|
||||
gated_delta_net_sycl<sv, KDA, keep_rs_t>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
|
@ -216,9 +241,9 @@ static void launch_gated_delta_net(const float * q_d,
|
|||
constexpr int sv = 128;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(
|
||||
gated_delta_net_sycl<sv, KDA, keep_rs_t>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
|
@ -290,14 +315,30 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor *
|
|||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const int K = (int) src_state->ne[1];
|
||||
const bool keep_rs = K > 1;
|
||||
|
||||
if (kda) {
|
||||
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, stream);
|
||||
if (keep_rs) {
|
||||
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
} else {
|
||||
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
}
|
||||
} else {
|
||||
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, stream);
|
||||
if (keep_rs) {
|
||||
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
} else {
|
||||
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue