vulkan: support q4_0/q8_0 KV in scalar FA

Jeff Bolz 2025-05-07 23:53:38 -05:00
parent 989bfb18fc
commit e66094276b
3 changed files with 66 additions and 8 deletions
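Summary: the scalar flash-attention path previously required an f16 K/V cache; quantized KV types only ran through the coopmat2 path. This change registers scalar FA pipelines for q4_0 and q8_0, updates the device supports_op check to admit them, and adds in-shader dequantization of K and V blocks to flash_attn.comp.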

ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -1939,6 +1939,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
 
     CREATE_FA(GGML_TYPE_F16, f16, true, )
+    CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
+    CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     if (device->coopmat2) {
         CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
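CREATE_FA appears to fan out to one CREATE_FA2 instantiation per supported head size (the 256 case is visible above), with the third argument selecting the scalar path; the two added lines register q4_0 and q8_0 pipelines alongside the existing f16 ones.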
@@ -9603,10 +9605,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             switch (op->src[1]->type) {
             case GGML_TYPE_F16:
             case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q8_0:
+                // supported in scalar and coopmat2 paths
+                break;
             case GGML_TYPE_Q4_1:
             case GGML_TYPE_Q5_0:
             case GGML_TYPE_Q5_1:
-            case GGML_TYPE_Q8_0:
             // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
             //case GGML_TYPE_Q2_K:
             //case GGML_TYPE_Q3_K:
@@ -9622,13 +9626,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             //case GGML_TYPE_IQ3_S:
             //case GGML_TYPE_IQ4_XS:
             case GGML_TYPE_IQ4_NL:
+                // currently supported only in coopmat2 path
+                if (!coopmat2) {
+                    return false;
+                }
                 break;
             default:
                 return false;
             }
-            if (!coopmat2 && op->src[1]->type != GGML_TYPE_F16) {
-                return false;
-            }
             if (!coopmat2 && !device->subgroup_shuffle) {
                 // scalar FA uses subgroupShuffle
                 return false;
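The net effect of these two hunks can be summarized as follows. This is an illustrative C++ sketch, not repository code: fa_kv_type_supported is a hypothetical helper, and the two booleans stand in for the coopmat2 and device->subgroup_shuffle checks above.

    #include "ggml.h"

    // Which KV-cache types FLASH_ATTN_EXT accepts after this change, per path.
    static bool fa_kv_type_supported(ggml_type type, bool coopmat2, bool subgroup_shuffle) {
        switch (type) {
        case GGML_TYPE_F16:
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
            // Supported by both paths; scalar FA still requires subgroupShuffle.
            return coopmat2 || subgroup_shuffle;
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_IQ4_NL:
            // Still limited to the coopmat2 path.
            return coopmat2;
        default:
            return false;
        }
    }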

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

@@ -72,6 +72,36 @@ layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
 layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
 
+#if defined(A_TYPE_PACKED16)
+#define BINDING_IDX_K 0
+#define BINDING_IDX_V 1
+layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+#endif
+
+#if defined(DATA_A_Q4_0)
+#define BLOCK_BYTE_SIZE 18
+
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+    uint shift = (iqs & 0x10) >> 2;
+    vui_lo >>= shift;
+    vui_hi >>= shift;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+#define BLOCK_BYTE_SIZE 34
+
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
 #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
 // Store the output when doing grouped query attention.
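The two BLOCK_BYTE_SIZE values match ggml's block formats: q4_0 packs an f16 scale d plus 32 4-bit quants into 18 bytes (values d * (q - 8)), and q8_0 packs an f16 scale plus 32 signed bytes into 34 bytes (values d * q). A minimal CPU-side sketch of the same math, assuming a caller-supplied f16-to-f32 helper:

    #include <stdint.h>

    struct block_q4_0 { uint16_t d; uint8_t qs[16]; }; // 18 bytes, 32 values
    struct block_q8_0 { uint16_t d; int8_t  qs[32]; }; // 34 bytes, 32 values

    // q4_0: values 0..15 sit in the low nibbles, 16..31 in the high nibbles,
    // each biased by -8, matching the shader's shift = (iqs & 0x10) >> 2 and
    // the trailing "- 8.0f".
    static float dequant_q4_0(const struct block_q4_0 * b, int i, float (*f16_to_f32)(uint16_t)) {
        const int q = (i < 16) ? (b->qs[i] & 0xF) : (b->qs[i - 16] >> 4);
        return f16_to_f32(b->d) * (float)(q - 8);
    }

    // q8_0: a plain scaled int8 lookup.
    static float dequant_q8_0(const struct block_q8_0 * b, int i, float (*f16_to_f32)(uint16_t)) {
        return f16_to_f32(b->d) * (float)b->qs[i];
    }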
@@ -208,6 +238,14 @@ void main() {
         }
     }
 
+#if BLOCK_SIZE > 1
+    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
+    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
+#else
+    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
+    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
+#endif
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
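Note the unit change behind this hunk: nb12/nb13 and nb22/nb23 are byte strides, so for quantized KV the offset is divided by BLOCK_BYTE_SIZE to yield a block index for dequantize4, while the f16 path divides by 2 (the size of a float16_t) to yield an element index, as before. The offsets also move out of the inner loops, which previously recomputed them per iteration.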
@@ -218,11 +256,17 @@
             }
         }
 
-        uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
-
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                uint ib = coord / BLOCK_SIZE;
+                uint iqs = (coord % BLOCK_SIZE);
+                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
                 vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+#endif
+
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
                 }
@@ -297,11 +341,16 @@
             }
         }
 
-        uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
-
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                uint ib = coord / BLOCK_SIZE;
+                uint iqs = (coord % BLOCK_SIZE);
+                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+#endif
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Of[r][d] += Pf[r][c] * Vf;
                }
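In both the K and V loops, coord is the flat element index of the first of the four values the thread loads; ib = coord / BLOCK_SIZE selects the quant block (relative to k_offset or v_offset) and iqs = coord % BLOCK_SIZE the position within it, from which dequantize4 returns four dequantized floats. The extra * BLOCK_SIZE when forming coord suggests k_stride and v_stride are expressed in blocks for quantized types, whereas the f16 path treats them as element counts.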

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -444,7 +444,11 @@ void process_shaders() {
         if (tname == "f16") {
             string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
                 merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
-        } // quants not supported yet
+        } else if (tname == "q4_0" || tname == "q8_0") {
+            std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+            string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_" + to_uppercase(tname) }}), true, false, false, f16acc);
+        }
     }
 }
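For the quantized variants the generator defines DATA_A_Q4_0 or DATA_A_Q8_0, which selects the matching dequantize4 and BLOCK_BYTE_SIZE in flash_attn.comp, and sets BLOCK_SIZE to the type's QUANT_K value (32 for both q4_0 and q8_0). The f16 build presumably leaves BLOCK_SIZE at its default of 1, so the BLOCK_SIZE > 1 branches compile out and the original f16vec4 loads are kept.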