mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
vulkan: reduce register usage in the scalar flash attention (FA) shader, though performance may be slightly worse
This commit is contained in:
parent
a6c940bb79
commit
c747227a57
2 changed files with 27 additions and 15 deletions
|
@ -1921,14 +1921,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
};
|
};
|
||||||
|
|
||||||
#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
|
#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1]); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1]); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1]); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1]); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
||||||
|
|
||||||
#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
|
#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
|
||||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
|
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
|
||||||
|
|
|
@ -106,6 +106,9 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
|
||||||
|
|
||||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||||
|
|
||||||
|
shared float16_t masksh[Bc][Br];
|
||||||
|
shared vec4 Qf[Br][D / 4];
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||||
init_iq_shmem(gl_WorkGroupSize);
|
init_iq_shmem(gl_WorkGroupSize);
|
||||||
|
@ -164,11 +167,11 @@ void main() {
|
||||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||||
|
|
||||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||||
vec4 Qf[Br][D_per_thread / 4];
|
|
||||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||||
if (i * Br + r < N) {
|
if (i * Br + r < N) {
|
||||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d * D_split + d_tid]) * p.scale;
|
Qf[r][d * D_split + d_tid] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d * D_split + d_tid]) * p.scale;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -218,7 +221,7 @@ void main() {
|
||||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||||
Sf[r][c] += dot(Qf[r][d], K_Tf);
|
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -241,13 +244,22 @@ void main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (p.mask != 0) {
|
if (p.mask != 0) {
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||||
|
uint32_t c = (idx + tid) % Bc;
|
||||||
|
uint32_t r = (idx + tid) / Bc;
|
||||||
|
masksh[c][r] = data_m[(i * Br + r) * m_stride + (j * Bc + c)];
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||||
float mvf = data_m[(i * Br + r) * m_stride + (j * Bc + c * cols_per_iter + col_tid)];
|
float mvf = float(masksh[c * cols_per_iter + col_tid][r]);
|
||||||
|
|
||||||
Sf[r][c] += slope[r]*mvf;
|
Sf[r][c] += slope[r]*mvf;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
barrier();
|
||||||
}
|
}
|
||||||
|
|
||||||
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
|
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
|
||||||
|
@ -363,7 +375,7 @@ void main() {
|
||||||
|
|
||||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||||
if (r < N) {
|
if (r < N) {
|
||||||
for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||||
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
||||||
}
|
}
|
||||||
|
@ -398,7 +410,7 @@ void main() {
|
||||||
if (p.gqa_ratio > 1) {
|
if (p.gqa_ratio > 1) {
|
||||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||||
if (r < N) {
|
if (r < N) {
|
||||||
for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||||
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
||||||
}
|
}
|
||||||
|
@ -408,7 +420,7 @@ void main() {
|
||||||
} else {
|
} else {
|
||||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||||
if (i * Br + r < N) {
|
if (i * Br + r < N) {
|
||||||
for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||||
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue