mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 17:44:38 +00:00
* vulkan: allow FA split_k with smaller KV values
* vulkan: spread split_k_reduce work across more threads
  k_num can get rather large. Use the whole workgroup to reduce the M/L values.
  Launch a thread for each element in the HSV dimension of the output. Helps a lot for large HSV (like deepseek).
91 lines
2.4 KiB
Text
91 lines
2.4 KiB
Text
#version 450

// Needed for the [[unroll]] loop attributes used in main().
#extension GL_EXT_control_flow_attributes : enable

// Specialization constant 0 (default 32). It sets both the workgroup width in x
// (via local_size_x_id below) and the length of the shared scratch array.
layout(constant_id = 0) const uint BLOCK_SIZE = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Input buffer. As indexed by main(), it holds two regions:
//   - partial O values at  D*N*(k + iq3*k_num) + D*n + d
//   - starting at D*N*ne3*k_num: for each (iq3, k) a run of N "l" values
//     followed by N "m" values (2*N floats per k).
layout (binding = 0) readonly buffer A {
    float data_a[];
};

// Output buffer, written at  iq3*D*N + D*n + d.
layout (binding = 1) writeonly buffer D {
    float data_d[];
};

layout (push_constant) uniform parameter {
    uint D;      // number of output elements per row (upper bound for d)
    uint N;      // row count used in the offset math (n = gl_WorkGroupID.x)
    uint ne3;    // extent of the iq3 dimension; only used to locate the l/m region
    uint k_num;  // number of partial results combined per output element
} p;

// Scratch space for the workgroup-wide max and sum reductions.
shared float tmpsh[BLOCK_SIZE];
|
|
|
|
// Combines the k_num partial results of one row into the final output:
//
//   out[d] = sum_k exp(m_k - m_max) * O_k[d]  /  sum_k exp(m_k - m_max) * l_k
//
// where m_max is the maximum m_k for the row. The max and the denominator are
// computed cooperatively by the whole workgroup; each invocation then produces
// a single output element d.
void main() {
    // Each workgroup handles a row; y selects a BLOCK_SIZE-wide slice of D,
    // z selects the iq3 index.
    const uint row  = gl_WorkGroupID.x;
    const uint lane = gl_LocalInvocationID.x;
    const uint iq3  = gl_WorkGroupID.z;

    const uint dim    = p.D;
    const uint rows   = p.N;
    const uint splits = p.k_num;

    // The l/m values start after all partial O values. For a given (iq3, k)
    // there are `rows` l values followed by `rows` m values.
    const uint lm_base  = dim * rows * p.ne3 * splits + rows * iq3 * splits * 2;
    const uint l_base   = lm_base + row;
    const uint m_base   = lm_base + rows + row;
    const uint lm_pitch = rows * 2;

    // Pass 1: every invocation scans a strided subset of the m values.
    float row_max = -1.0/0.0;
    for (uint k = lane; k < splits; k += BLOCK_SIZE) {
        row_max = max(row_max, data_a[m_base + k * lm_pitch]);
    }

    // Tree-reduce the per-invocation maxima through shared memory.
    tmpsh[lane] = row_max;
    barrier();
    [[unroll]] for (uint stride = BLOCK_SIZE/2; stride > 0; stride >>= 1) {
        if (lane < stride) {
            row_max = max(row_max, tmpsh[lane + stride]);
            tmpsh[lane] = row_max;
        }
        barrier();
    }
    row_max = tmpsh[0];

    // Everyone must have read tmpsh[0] before the array is reused below.
    barrier();

    // Pass 2: accumulate the rescaled l values (the normalisation term).
    float denom = 0.0;
    for (uint k = lane; k < splits; k += BLOCK_SIZE) {
        const float l = data_a[l_base + k * lm_pitch];
        const float m = data_a[m_base + k * lm_pitch];
        denom += exp(m - row_max) * l;
    }

    // Tree-reduce the partial sums through shared memory.
    tmpsh[lane] = denom;
    barrier();
    [[unroll]] for (uint stride = BLOCK_SIZE/2; stride > 0; stride >>= 1) {
        if (lane < stride) {
            denom += tmpsh[lane + stride];
            tmpsh[lane] = denom;
        }
        barrier();
    }
    denom = tmpsh[0];

    const float inv_denom = 1.0 / denom;

    // D dimension is split across workgroups in the y dimension.
    const uint d = lane + gl_WorkGroupID.y * BLOCK_SIZE;

    // No barriers follow, so out-of-range invocations can simply leave.
    if (d >= dim) {
        return;
    }

    // Rescale each partial O by exp(m - row_max), sum them, normalise and store.
    float acc = 0.0;
    [[unroll]] for (uint k = 0; k < splits; ++k) {
        const uint o_index = dim * rows * (k + iq3 * splits) + dim * row + d;
        const float m = data_a[m_base + k * lm_pitch];
        acc += exp(m - row_max) * data_a[o_index];
    }
    acc *= inv_denom;

    data_d[iq3 * dim * rows + dim * row + d] = acc;
}
|