mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-12 09:59:41 +00:00
vulkan: optimize flash attention split_k_reduce (#14554)
* vulkan: allow FA split_k with smaller KV values * vulkan: spread split_k_reduce work across more threads k_num can get rather large. Use the whole workgroup to reduce the M/L values. Launch a thread for each element in the HSV dimension of the output. Helps a lot for large HSV (like deepseek).
This commit is contained in:
parent
699f4392a3
commit
6efcd65945
2 changed files with 42 additions and 12 deletions
|
@ -2,9 +2,9 @@
|
|||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||
|
@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
|
|||
uint k_num;
|
||||
} p;
|
||||
|
||||
shared float tmpsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
// Each workgroup handles a row
|
||||
const uint n = gl_WorkGroupID.x;
|
||||
|
@ -32,23 +34,51 @@ void main() {
|
|||
|
||||
// Compute the max m value for the row
|
||||
float m_max = -1.0/0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
m_max = max(m_max, m);
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = m_max;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
m_max = max(m_max, tmpsh[tid + s]);
|
||||
tmpsh[tid] = m_max;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
m_max = tmpsh[0];
|
||||
|
||||
barrier();
|
||||
|
||||
// Compute L based on m_max
|
||||
float L = 0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float l = data_a[l_offset + k * lm_stride];
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float l = data_a[l_offset + (k + tid) * lm_stride];
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
L += exp(m - m_max) * l;
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = L;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
L += tmpsh[tid + s];
|
||||
tmpsh[tid] = L;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
L = tmpsh[0];
|
||||
|
||||
L = 1.0 / L;
|
||||
|
||||
// D dimension is split across workgroups in the y dimension
|
||||
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
|
||||
// Scale and sum the O contributions based on m_max and store the result to memory
|
||||
for (uint d = tid; d < D; d += BLOCK_SIZE) {
|
||||
if (d < D) {
|
||||
float O = 0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue