vulkan: support softmax/FA batch and broadcast (#14449)

This commit is contained in:
Jeff Bolz 2025-07-01 03:32:56 -05:00 committed by Georgi Gerganov
parent ec68e84c32
commit 8875523eb3
7 changed files with 80 additions and 44 deletions

View file

@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint ne3;
uint k_num;
} p;
@ -19,13 +20,14 @@ void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint iq3 = gl_WorkGroupID.z;
uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;
uint l_offset = D * N * k_num + n;
uint m_offset = D * N * k_num + N + n;
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
uint lm_stride = N * 2;
// Compute the max m value for the row
@ -49,11 +51,11 @@ void main() {
for (uint d = tid; d < D; d += BLOCK_SIZE) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * k + D * n + d;
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
O *= L;
data_d[D * n + d] = O;
data_d[iq3 * D * N + D * n + d] = O;
}
}