Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-14 10:59:41 +00:00)
vulkan: use vector loads in scalar flash attention shader
parent 3a8d954e0c
commit 876e6617a7
2 changed files with 45 additions and 32 deletions
@@ -1911,7 +1911,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
 
         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
-        const uint32_t D_split = std::min(device->subgroup_size, 16u);
+        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
+        const uint32_t D_lsb = D ^ (D & (D-1));
+        uint32_t D_split = std::min(std::min(device->subgroup_size, 16u), D_lsb / 4);
 
         // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
         GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
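The new clamp uses a bit trick: D ^ (D & (D-1)) clears every bit of D except the lowest set one, i.e. it computes the largest power of two dividing D. Each thread now loads whole vec4s at stride D_split, so 4 * D_split must divide D exactly, hence D_split <= lsb(D) / 4. A minimal standalone sketch of the arithmetic (the subgroup size of 32 and the head sizes are illustrative values, not taken from the commit):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t subgroup_size = 32;  // assumed device value
        for (uint32_t D : {64u, 80u, 96u, 112u, 128u}) {
            const uint32_t D_lsb = D ^ (D & (D - 1));  // lowest set bit of D
            const uint32_t D_split = std::min(std::min(subgroup_size, 16u), D_lsb / 4);
            // e.g. D=80 -> D_lsb=16 -> D_split=4 (4 lanes * vec4 = 16, divides 80)
            printf("D=%3u  D_lsb=%3u  D_split=%2u  4*D_split divides D: %s\n",
                   D, D_lsb, D_split, D % (4 * D_split) == 0 ? "yes" : "no");
        }
    }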
@@ -64,8 +64,11 @@ layout (push_constant) uniform parameter {
 } p;
 
 layout (binding = 0) readonly buffer Q {float data_q[];};
+layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
 layout (binding = 1) readonly buffer K {float16_t data_k[];};
+layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
 layout (binding = 2) readonly buffer V {float16_t data_v[];};
+layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
 layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
 
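The three new declarations don't add descriptor slots: each reuses the binding of the scalar buffer above it, giving the shader an aliased vec4 (or f16vec4) view of the same memory so that one load fetches four elements at once. Roughly the CPU-side picture, as an illustrative sketch only:

    #include <array>
    #include <cstdio>
    #include <cstring>

    struct vec4 { float x, y, z, w; };

    int main() {
        alignas(16) std::array<float, 8> data = {0, 1, 2, 3, 4, 5, 6, 7};
        float s = data[5];                   // scalar view, like data_q[]
        vec4 v;                              // vec4 view, like data_qv4[]:
        std::memcpy(&v, data.data() + 4, sizeof v);  // element 1 = scalars 4..7
        printf("s=%g  v={%g,%g,%g,%g}\n", s, v.x, v.y, v.z, v.w);
    }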
@@ -161,19 +164,19 @@ void main() {
     uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
 
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
-    float Qf[Br][D_per_thread];
+    vec4 Qf[Br][D_per_thread / 4];
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
         if (i * Br + r < N) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
-                Qf[r][d] = float(data_q[q_offset + (i * Br + r) * q_stride + d * D_split + d_tid]) * p.scale;
+            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d * D_split + d_tid]) * p.scale;
             }
         }
     }
 
-    float Of[Br][D_per_thread];
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
+    vec4 Of[Br][D_per_thread / 4];
+    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            Of[r][d] = 0.0;
+            Of[r][d] = vec4(0.0);
         }
     }
 
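Since data_qv4[] indexes vec4s instead of floats, the offset and stride are divided by 4 and the per-thread loop bound shrinks to D_per_thread / 4; d * D_split + d_tid now counts vec4s. A self-checking sketch that the old and new addressing touch exactly the same scalar elements once all D_split lanes are summed over (all sizes are made-up example values):

    #include <cstdint>
    #include <set>

    int main() {
        const uint32_t D_per_thread = 16, D_split = 4;         // example config
        const uint32_t q_offset = 64, q_stride = 128, row = 2; // example offsets
        std::set<uint32_t> scalar, vec;
        for (uint32_t d_tid = 0; d_tid < D_split; ++d_tid) {
            for (uint32_t d = 0; d < D_per_thread; ++d)        // old addressing
                scalar.insert(q_offset + row * q_stride + d * D_split + d_tid);
            for (uint32_t d = 0; d < D_per_thread / 4; ++d) {  // new addressing
                uint32_t idx = q_offset / 4 + row * q_stride / 4 + d * D_split + d_tid;
                for (uint32_t c = 0; c < 4; ++c)
                    vec.insert(4 * idx + c);                   // scalars in this vec4
            }
        }
        return scalar == vec ? 0 : 1;  // 0: both schemes cover the same elements
    }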
@@ -212,10 +215,10 @@ void main() {
         uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
-                float K_Tf = float(data_k[k_offset + (j * Bc + c * cols_per_iter + col_tid) * k_stride + d * D_split + d_tid]);
+            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    Sf[r][c] += Qf[r][d] * K_Tf;
+                    Sf[r][c] += dot(Qf[r][d], K_Tf);
                 }
             }
         }
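With Q and K both held as vec4s, four scalar multiply-adds of the Q*K^T inner product collapse into one dot(). A plain C++ stand-in for the GLSL change (illustrative only):

    #include <cstdio>

    int main() {
        const float Q[8] = {1, 2, 3, 4, 5, 6, 7, 8};
        const float K[8] = {8, 7, 6, 5, 4, 3, 2, 1};
        float s_old = 0.0f, s_new = 0.0f;
        for (int d = 0; d < 8; ++d)          // old: Sf += Qf[r][d] * K_Tf
            s_old += Q[d] * K[d];
        for (int d = 0; d < 8 / 4; ++d) {    // new: Sf += dot(Qf[r][d], K_Tf)
            float dp = 0.0f;
            for (int c = 0; c < 4; ++c)
                dp += Q[4 * d + c] * K[4 * d + c];
            s_new += dp;
        }
        printf("%g == %g\n", s_old, s_new);  // same accumulation either way
    }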
@@ -275,21 +278,21 @@ void main() {
 
         uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 
-        float PVf[Br][D_per_thread];
-        [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
+        vec4 PVf[Br][D_per_thread / 4];
+        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
             [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                PVf[r][d] = 0.0;
+                PVf[r][d] = vec4(0.0);
             }
         }
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
-                float Vf = float(data_v[v_offset + (j * Bc + c * cols_per_iter + col_tid) * v_stride + d * D_split + d_tid]);
+            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     PVf[r][d] += Pf[r][c] * Vf;
                 }
             }
         }
-        [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
+        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
             [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                 Of[r][d] = eMf[r] * Of[r][d] + PVf[r][d];
             }
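The P*V pass needs nothing beyond the /4 bounds: the softmax weight Pf[r][c] stays a scalar and GLSL broadcasts it across the f16vec4 load of V. A C++ stand-in for that broadcast (illustrative only):

    #include <cstdio>

    struct vec4 { float v[4]; };

    // scalar * vec4 broadcast, as GLSL evaluates Pf[r][c] * Vf
    static vec4 mul(float s, vec4 a) {
        for (float& x : a.v) x *= s;
        return a;
    }

    int main() {
        vec4 Vf = {{1, 2, 3, 4}}, PVf = {{0, 0, 0, 0}};
        float Pf = 0.5f;
        vec4 t = mul(Pf, Vf);                    // PVf[r][d] += Pf[r][c] * Vf
        for (int c = 0; c < 4; ++c) PVf.v[c] += t.v[c];
        printf("{%g,%g,%g,%g}\n", PVf.v[0], PVf.v[1], PVf.v[2], PVf.v[3]);
    }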
@@ -337,21 +340,23 @@ void main() {
             Lf[r] = tmpsh[d_tid];
             barrier();
 
-        [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
+        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
 
             Of[r][d] = eMf * Of[r][d];
-            tmpsh[tid] = Of[r][d];
-            barrier();
-            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-                if (tid < s) {
-                    Of[r][d] += tmpsh[tid + s];
-                    tmpsh[tid] = Of[r][d];
-                }
-                barrier();
-            }
-            Of[r][d] = tmpsh[d_tid];
-            barrier();
+            [[unroll]] for (uint32_t c = 0; c < 4; ++c) {
+                tmpsh[tid] = Of[r][d][c];
+                barrier();
+                [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+                    if (tid < s) {
+                        Of[r][d][c] += tmpsh[tid + s];
+                        tmpsh[tid] = Of[r][d][c];
+                    }
+                    barrier();
+                }
+                Of[r][d][c] = tmpsh[d_tid];
+                barrier();
+            }
         }
     }
 
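The cross-thread reduction is where vectorization costs extra code: tmpsh stages one float per thread, so the vec4 accumulator has to be tree-reduced one component at a time, which is why this hunk grows an inner c loop with its own barriers. A serial model of the pattern (workgroup size and D_split are made-up values, and the "threads" are simulated by loops):

    #include <cstdio>

    int main() {
        const int WG = 8;                   // stand-in for gl_WorkGroupSize.x
        const int D_split = 2;              // reduction stops at D_split partials
        float Of[WG][4], tmpsh[WG];
        for (int t = 0; t < WG; ++t)
            for (int c = 0; c < 4; ++c)
                Of[t][c] = float(t * 4 + c);
        for (int c = 0; c < 4; ++c) {       // the new loop: one pass per lane
            for (int t = 0; t < WG; ++t)
                tmpsh[t] = Of[t][c];        // stage this lane in "shared memory"
            for (int s = WG / 2; s >= D_split; s >>= 1)  // tree reduction
                for (int t = 0; t < s; ++t)
                    tmpsh[t] += tmpsh[t + s];            // barrier after each step
            for (int t = 0; t < WG; ++t)
                Of[t][c] = tmpsh[t % D_split];  // each thread reads its d_tid slot
        }
        printf("lane sums at thread 0: %g %g %g %g\n",
               Of[0][0], Of[0][1], Of[0][2], Of[0][3]);
    }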
@@ -363,8 +368,10 @@ void main() {
 
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
-                for (uint32_t d = 0; d < D_per_thread; ++d) {
-                    perElemOpGqaStore(r, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
+                for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+                    }
                 }
             }
         }
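On the store side, the scalar slot d * D_split + d_tid fans out to the four slots 4*(d * D_split + d_tid) + comp; the last hunk below applies the same expansion to both the gqa and the plain data_o store path. A self-checking sketch that the remapped writes still hit each output element exactly once (example sizes only):

    #include <cstdint>
    #include <vector>

    int main() {
        const uint32_t D_per_thread = 16, D_split = 4;  // example config, D = 64
        std::vector<int> hits(D_per_thread * D_split, 0);
        for (uint32_t d_tid = 0; d_tid < D_split; ++d_tid)
            for (uint32_t d = 0; d < D_per_thread / 4; ++d)
                for (uint32_t comp = 0; comp < 4; ++comp)
                    ++hits[4 * (d * D_split + d_tid) + comp];
        for (int h : hits)
            if (h != 1) return 1;  // a gap or double-write would be a bug
        return 0;                  // every output element written exactly once
    }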
@@ -385,7 +392,7 @@ void main() {
         Lfrcp[r] = 1.0 / Lf[r];
     }
 
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
+    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             Of[r][d] *= Lfrcp[r];
         }
@@ -396,16 +403,20 @@ void main() {
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
-                for (uint32_t d = 0; d < D_per_thread; ++d) {
-                    perElemOpGqaStore(r, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
+                for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+                    }
                 }
             }
         }
     } else {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (i * Br + r < N) {
-                for (uint32_t d = 0; d < D_per_thread; ++d) {
-                    data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + d * D_split + d_tid] = D_TYPE(Of[r][d]);
+                for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+                        data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+                    }
                 }
             }
         }