mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 17:44:38 +00:00
vulkan: optimize flash attention split_k_reduce (#14554)
* vulkan: allow FA split_k with smaller KV values * vulkan: spread split_k_reduce work across more threads k_num can get rather large. Use the whole workgroup to reduce the M/L values. Launch a thread for each element in the HSV dimension of the output. Helps a lot for large HSV (like deepseek).
This commit is contained in:
parent
699f4392a3
commit
6efcd65945
2 changed files with 42 additions and 12 deletions
|
@ -2706,7 +2706,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
||||
|
@ -6252,13 +6252,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
||||
|
||||
// Try to use split_k when KV is large enough to be worth the overhead
|
||||
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
|
||||
if (workgroups_x == 1 && shader_core_count > 0) {
|
||||
// Try to run two workgroups per SM.
|
||||
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
|
||||
if (split_k > 1) {
|
||||
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
||||
// of "align", so recompute split_k based on that.
|
||||
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
|
||||
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
|
||||
split_k = CEIL_DIV(KV, split_kv);
|
||||
workgroups_x = split_k;
|
||||
}
|
||||
|
@ -6392,7 +6392,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
||||
},
|
||||
pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
|
||||
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
|
||||
} else {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||
|
@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
|
|||
uint k_num;
|
||||
} p;
|
||||
|
||||
shared float tmpsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
// Each workgroup handles a row
|
||||
const uint n = gl_WorkGroupID.x;
|
||||
|
@ -32,23 +34,51 @@ void main() {
|
|||
|
||||
// Compute the max m value for the row
|
||||
float m_max = -1.0/0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
m_max = max(m_max, m);
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = m_max;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
m_max = max(m_max, tmpsh[tid + s]);
|
||||
tmpsh[tid] = m_max;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
m_max = tmpsh[0];
|
||||
|
||||
barrier();
|
||||
|
||||
// Compute L based on m_max
|
||||
float L = 0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float l = data_a[l_offset + k * lm_stride];
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float l = data_a[l_offset + (k + tid) * lm_stride];
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
L += exp(m - m_max) * l;
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = L;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
L += tmpsh[tid + s];
|
||||
tmpsh[tid] = L;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
L = tmpsh[0];
|
||||
|
||||
L = 1.0 / L;
|
||||
|
||||
// D dimension is split across workgroups in the y dimension
|
||||
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
|
||||
// Scale and sum the O contributions based on m_max and store the result to memory
|
||||
for (uint d = tid; d < D; d += BLOCK_SIZE) {
|
||||
if (d < D) {
|
||||
float O = 0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue