mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 11:40:43 +00:00
* vulkan: change gated_delta_net to shard a column across a subgroup This is based on https://github.com/ggml-org/llama.cpp/pull/20391, I used an LLM to port the CUDA code to Vulkan, and guided it to make various fixes to work with Vulkan (e.g. handling different subgroup sizes, unknown mapping of subgroup to invocation id, using subgroupAdd optionally, etc.). This fixes a perf regression from the transposing of the values in memory (!20443). * vulkan: Spread columns across fewer lanes to reduce the number of workgroups
169 lines
5.8 KiB
Text
#version 450
|
|
|
|
#extension GL_EXT_control_flow_attributes : require
|
|
#extension GL_KHR_shader_subgroup_basic : enable
|
|
#if USE_SUBGROUP_CLUSTERED
|
|
#extension GL_KHR_shader_subgroup_clustered : enable
|
|
#endif
|
|
#if USE_SUBGROUP_ADD
|
|
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
|
#endif
|
|
|
|
// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
// so no bounds checking is needed.

// State/value head dimension: the recurrent state is an S_V x S_V matrix per (seq, head).
layout(constant_id = 0) const uint S_V = 128;
// Nonzero selects per-row gating (S_V gate values per token/head); zero selects a
// single scalar gate per (token, head). See the g_exp load in main().
layout(constant_id = 1) const uint KDA = 0;
// Device subgroup size; also used as the workgroup size via local_size_x_id = 2.
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
// Number of subgroup lanes that cooperate on one state column.
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;

// State columns handled per workgroup (one per lane group).
const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
// State rows of a column held in registers by each lane.
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;

// Workgroup width equals SUBGROUP_SIZE (constant_id 2), i.e. exactly one subgroup.
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
uint H;         // number of heads
uint n_tokens;  // tokens per sequence
uint n_seqs;    // number of sequences
uint s_off;     // element offset of the final-state region within data_dst
uint sq1, sq2, sq3;  // q/k strides (elements): per-head, per-token, per-sequence
uint sv1, sv2, sv3;  // v strides (elements): per-head, per-token, per-sequence
uint sb1, sb2, sb3;  // beta/g strides (elements): per-head, per-token, per-sequence
uint neq1, rq3;      // head modulus / sequence divisor for q,k index broadcast
                     // NOTE(review): presumably GQA-style sharing -- confirm against host dispatch
float scale;         // scale applied to the attention output
};

layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; };
layout(binding = 1) readonly buffer KBuf { FLOAT_TYPE data_k[]; };
layout(binding = 2) readonly buffer VBuf { FLOAT_TYPE data_v[]; };
layout(binding = 3) readonly buffer GBuf { FLOAT_TYPE data_g[]; };
layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
// data_dst holds both the per-token attention outputs and, at s_off, the updated state.
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
|
#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
// Scratch for the shared-memory fallback reduction; one slot per lane
// (the workgroup is exactly one subgroup wide).
shared FLOAT_TYPE red_scratch[SUBGROUP_SIZE];

// Fallback sum-reduction across each group of LANES_PER_COLUMN lanes, used
// when neither subgroupAdd nor subgroupClusteredAdd is available. Every lane
// in a group returns the group's total. An XOR butterfly keeps each lane
// combining only with partners inside its own LANES_PER_COLUMN-sized cluster,
// so groups never mix.
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
    const uint tid = gl_SubgroupInvocationID;
    red_scratch[tid] = partial;
    barrier();
    [[unroll]] for (uint stride = LANES_PER_COLUMN / 2u; stride > 0; stride >>= 1u) {
        const FLOAT_TYPE neighbor = red_scratch[tid ^ stride];
        barrier();
        red_scratch[tid] += neighbor;
        barrier();
    }
    const FLOAT_TYPE total = red_scratch[tid];
    barrier();
    return total;
}
#endif
|
|
|
|
// clusterSize for subgroupClusteredAdd must be a compile-time constant, and GLSL
// requires a literal there, so we branch on the specialization constant instead.
// Every branch below folds away once LANES_PER_COLUMN is specialized.
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
    if (LANES_PER_COLUMN == 1u) {
        // A single lane owns the whole column; nothing to reduce.
        return partial;
    }
#if USE_SUBGROUP_CLUSTERED
    // Literal cluster sizes, one branch per supported LANES_PER_COLUMN.
    if (LANES_PER_COLUMN == 2u)  { return subgroupClusteredAdd(partial, 2u); }
    if (LANES_PER_COLUMN == 4u)  { return subgroupClusteredAdd(partial, 4u); }
    if (LANES_PER_COLUMN == 8u)  { return subgroupClusteredAdd(partial, 8u); }
    if (LANES_PER_COLUMN == 16u) { return subgroupClusteredAdd(partial, 16u); }
    if (LANES_PER_COLUMN == 32u) { return subgroupClusteredAdd(partial, 32u); }
    if (LANES_PER_COLUMN == 64u) { return subgroupClusteredAdd(partial, 64u); }
#endif
    // Fallback path (also covers unexpected cluster widths).
#if USE_SUBGROUP_ADD
    return subgroupAdd(partial);
#else
    return reduce_add_shmem(partial);
#endif
}
|
|
|
|
// One workgroup processes one (head, sequence) pair for COLS_PER_WG state
// columns. Each S_V-element state column is sharded across LANES_PER_COLUMN
// subgroup lanes (ROWS_PER_LANE register-resident rows per lane), and tokens
// are processed sequentially because the state update at token t feeds t+1.
void main() {
    const uint head_id = gl_WorkGroupID.x;
    const uint seq_id = gl_WorkGroupID.y;
    // Position of this invocation inside its lane group, and which state
    // column of this head it owns.
    const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
    const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);

    // Broadcast indices for q/k loads -- NOTE(review): assumed GQA-style
    // head/sequence sharing driven by neq1/rq3; confirm against host code.
    const uint iq1 = head_id % neq1;
    const uint iq3 = seq_id / rq3;

    const uint state_size = S_V * S_V;
    const uint state_base = (seq_id * H + head_id) * state_size;

    // Load this lane's shard of the state column into registers. Rows are
    // strided by LANES_PER_COLUMN so adjacent lanes touch adjacent elements.
    FLOAT_TYPE s_shard[ROWS_PER_LANE];
    [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
        s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
    }

    // Output offset for token 0 of this (seq, head); advanced by S_V * H per token.
    uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;

    for (uint t = 0; t < n_tokens; t++) {
        const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
        // k shares q's layout and broadcast indices.
        const uint k_off = q_off;
        const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
        const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
        const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);

        // This lane's slices of k and q for the current token.
        FLOAT_TYPE k_reg[ROWS_PER_LANE];
        FLOAT_TYPE q_reg[ROWS_PER_LANE];
        [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
            const uint i = r * LANES_PER_COLUMN + lane;
            k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
            q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
        }

        // Decay factors: KDA == 0 means one scalar gate per (token, head),
        // otherwise one gate per state row.
        FLOAT_TYPE g_exp[ROWS_PER_LANE];
        if (KDA == 0) {
            const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
            [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
                g_exp[r] = g_val;
            }
        } else {
            const uint g_base = gb_off * S_V;
            [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
                const uint i = r * LANES_PER_COLUMN + lane;
                g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
            }
        }

        const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);

        // kv_col = dot(decayed state column, k), reduced across the lane group.
        FLOAT_TYPE kv_shard = 0.0;
        [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
            kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
        }
        FLOAT_TYPE kv_col = reduce_partial(kv_shard);

        // Delta-rule correction term for this column.
        FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;

        // Fused update: S = g*S + k*delta, accumulating dot(S, q) in the same pass.
        FLOAT_TYPE attn_partial = 0.0;
        [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
            s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
            attn_partial += s_shard[r] * q_reg[r];
        }
        FLOAT_TYPE attn_col = reduce_partial(attn_partial);

        // One lane per column writes the scaled output element.
        if (lane == 0) {
            data_dst[attn_off + col] = attn_col * scale;
        }

        attn_off += S_V * H;
    }

    // Persist this lane's final state shard into dst's state region at s_off.
    [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
        data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
    }
}
|