vulkan: for scalar FA, select between 1 and 8 rows

This commit is contained in:
Jeff Bolz 2025-05-08 14:34:59 -05:00
parent 00784e3d34
commit 615958f42c
2 changed files with 36 additions and 22 deletions

View file

@ -1590,7 +1590,8 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 8;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
static uint32_t get_fa_num_small_rows(bool scalar) {
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
@ -1599,8 +1600,16 @@ static uint32_t get_fa_num_small_rows(bool scalar) {
static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
GGML_UNUSED(clamp);
if (scalar) {
if (small_rows) {
return {scalar_flash_attention_num_small_rows, 64};
} else {
return {scalar_flash_attention_num_large_rows, 32};
}
}
// small rows, large cols
if (small_rows || scalar) {
if (small_rows) {
return {get_fa_num_small_rows(scalar), 32};
}
@ -5729,8 +5738,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
assert(q->type == GGML_TYPE_F32);
assert(k->type == v->type);
vk_pipeline *pipelines;
bool scalar = !ctx->device->coopmat2;
uint32_t gqa_ratio = 1;
uint32_t qk_ratio = neq2 / nek2;
uint32_t workgroups_x = (uint32_t)neq1;
uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3;
// For scalar FA, we can use the "large" size to accommodate qga.
// For coopmat FA, we always use the small size (which is still pretty large for gqa).
const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
gqa_ratio = qk_ratio;
N = gqa_ratio;
workgroups_y /= N;
}
vk_pipeline *pipelines;
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
bool small_rows = N <= get_fa_num_small_rows(scalar);
@ -5776,24 +5806,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_pipeline pipeline = pipelines[aligned];
assert(pipeline);
uint32_t gqa_ratio = 1;
uint32_t qk_ratio = neq2 / nek2;
uint32_t workgroups_x = (uint32_t)neq1;
uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3;
const uint32_t max_gqa = get_fa_num_small_rows(scalar);
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
gqa_ratio = qk_ratio;
N = gqa_ratio;
workgroups_y /= N;
}
uint32_t split_kv = KV;
uint32_t split_k = 1;

View file

@ -295,7 +295,9 @@ void main() {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
masksh[c][r] = data_m[(i * Br + r) * m_stride + (j * Bc + c)];
if (idx + tid < Bc * Br) {
masksh[c][r] = data_m[(i * Br + r) * m_stride + (j * Bc + c)];
}
}
barrier();