mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-12 09:59:41 +00:00
vulkan: support softmax/FA batch and broadcast (#14449)
This commit is contained in:
parent
ec68e84c32
commit
8875523eb3
7 changed files with 80 additions and 44 deletions
|
@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
|
|||
layout (push_constant) uniform parameter {
|
||||
uint D;
|
||||
uint N;
|
||||
uint ne3;
|
||||
uint k_num;
|
||||
} p;
|
||||
|
||||
|
@ -19,13 +20,14 @@ void main() {
|
|||
// Each workgroup handles a row
|
||||
const uint n = gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint iq3 = gl_WorkGroupID.z;
|
||||
|
||||
uint D = p.D;
|
||||
uint N = p.N;
|
||||
uint k_num = p.k_num;
|
||||
|
||||
uint l_offset = D * N * k_num + n;
|
||||
uint m_offset = D * N * k_num + N + n;
|
||||
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
|
||||
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
|
||||
uint lm_stride = N * 2;
|
||||
|
||||
// Compute the max m value for the row
|
||||
|
@ -49,11 +51,11 @@ void main() {
|
|||
for (uint d = tid; d < D; d += BLOCK_SIZE) {
|
||||
float O = 0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
uint o_offset = D * N * k + D * n + d;
|
||||
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
O += exp(m - m_max) * data_a[o_offset];
|
||||
}
|
||||
O *= L;
|
||||
data_d[D * n + d] = O;
|
||||
data_d[iq3 * D * N + D * n + d] = O;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue