vulkan: Support asymmetric FA in scalar/mmq/coopmat1 paths (#22589)

This commit is contained in:
Jeff Bolz 2026-05-11 05:49:03 -05:00 committed by GitHub
parent 8cef8201a1
commit dd9280a664
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 632 additions and 680 deletions

View file

@ -855,7 +855,7 @@ struct vk_device_struct {
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16;
std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
@ -2933,10 +2933,10 @@ struct vk_fa_tuning_params {
}
};
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type);
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
vk_fa_tuning_params result{};
result.path = FA_SCALAR;
@ -2988,7 +2988,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) {
result.block_rows /= 2;
}
@ -3011,10 +3011,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
return result;
}
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
GGML_UNUSED(n_rows);
GGML_UNUSED(n_kv);
GGML_UNUSED(kv_type);
GGML_UNUSED(k_type);
GGML_UNUSED(v_type);
GGML_UNUSED(f32acc);
vk_fa_tuning_params result{};
@ -3070,12 +3071,6 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
}
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
// Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1.
if (k_type != v_type) {
GGML_ASSERT(device->coopmat2);
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
}
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
@ -3087,7 +3082,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
if (path == FA_COOPMAT1) {
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
if (!shape_ok || !shmem_ok) {
@ -3107,9 +3102,9 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
switch (path) {
case FA_SCALAR:
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
case FA_COOPMAT1:
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
case FA_COOPMAT2:
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
default:
@ -3279,6 +3274,20 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
return 0; // If no matching configuration is found
}
// Whether scalar flash attention will use the MMQ path for the given k_type.
static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) {
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
return device->integer_dot_product && device->subgroup_clustered &&
(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 ||
k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 ||
k_type == GGML_TYPE_Q8_0);
#else
GGML_UNUSED(device);
GGML_UNUSED(k_type);
return false;
#endif
}
static void ggml_vk_load_shaders(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
@ -3525,121 +3534,96 @@ static void ggml_vk_load_shaders(vk_device& device) {
align, disable_robustness, require_full_subgroups, required_subgroup_size);
};
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
FaCodePath path = fa.first.path; \
uint32_t Br = fa.first.Br; \
uint32_t Bc = fa.first.Bc; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
uint32_t fa_sgs = fa.first.subgroup_size; \
bool fa_ds = fa.first.subgroup_size == 0; \
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} \
} \
} \
}
// FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V
// quant type is selected at runtime via the FaTypeK / FaTypeV spec constants.
if (device->fp16) {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
if (fa.first.path != FA_SCALAR) continue;
const uint32_t Br = fa.first.Br;
const uint32_t Bc = fa.first.Bc;
const bool aligned = fa.first.aligned;
const bool f32acc = fa.first.f32acc;
const uint32_t fa_sgs = fa.first.subgroup_size;
const bool fa_ds = fa.first.subgroup_size == 0;
const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type);
const void * spv_data = nullptr;
size_t spv_size = 0;
if (use_mmq) {
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
} else
if (device->fp16) {
if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; }
else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; }
} else {
spv_data = flash_attn_f32_f16_fp32_int8_data;
spv_size = flash_attn_f32_f16_fp32_int8_len;
}
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
}
} else {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
} else {
if (device->fp16) {
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
} else {
spv_data = flash_attn_f32_f16_fp32_data;
spv_size = flash_attn_f32_f16_fp32_len;
}
}
const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
!fa_ds, !fa_ds ? fa_sgs : 0);
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1)
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
#define CREATE_FA_CM2_MIXED() \
for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \
for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \
FaCodePath path = fa.first.path; \
uint32_t Br = fa.first.Br; \
uint32_t Bc = fa.first.Bc; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
if (path == FA_COOPMAT2) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} \
} \
} \
} \
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
if (fa.first.path != FA_COOPMAT1) continue;
const uint32_t Br = fa.first.Br;
const uint32_t Bc = fa.first.Bc;
const bool aligned = fa.first.aligned;
const bool f32acc = fa.first.f32acc;
const uint32_t fa_sgs = fa.first.subgroup_size;
const bool fa_ds = fa.first.subgroup_size == 0;
const void * spv_data;
size_t spv_size;
if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
!fa_ds, !fa_ds ? fa_sgs : 0);
}
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
if (fa.first.path != FA_COOPMAT2) continue;
const uint32_t Br = fa.first.Br;
const uint32_t Bc = fa.first.Bc;
const bool aligned = fa.first.aligned;
const bool f32acc = fa.first.f32acc;
const void * spv_data;
size_t spv_size;
const char * name;
if (aligned) {
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; }
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; }
} else {
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; }
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; }
}
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0);
}
if (device->coopmat2) {
CREATE_FA_CM2_MIXED();
}
#undef CREATE_FA_CM2_MIXED
#endif
#undef CREATE_FA
const int mul_mat_id_param_count = 5;
@ -8940,8 +8924,9 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) {
GGML_UNUSED(f32acc);
GGML_UNUSED(v_type);
// Needs to be kept up to date on shader changes
const uint32_t wg_size = params.workgroup_size;
const uint32_t Br = params.block_rows;
@ -8949,10 +8934,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type);
// tmpsh is overestimated slightly
const uint32_t tmpsh = wg_size * sizeof(float);
@ -8969,17 +8951,10 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
// kvsh uses D = HSV (K goes through kblocksh instead)
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
// block_a_cache size depends on quant type
uint32_t block_a_size;
switch (kv_type) {
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
default: block_a_size = 0; break;
}
// The mixed MMQ shader uses a superset block_a_cache that fits every
// FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm.
// Single-scale types leave dm.y unused; non-Q5_* leave qh unused.
const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size;
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
} else {
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
@ -9117,10 +9092,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
if (tuning_params.path != FA_COOPMAT2) {
GGML_ASSERT(k->type == v->type);
}
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
@ -9164,7 +9135,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
auto it = pipelines.find(fa_pipeline_state);
if (it != pipelines.end()) {
pipeline = it->second;
@ -15642,10 +15613,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
// mismatching K/V type is currently supported for coopmat2 only.
if (op->src[1]->type != op->src[2]->type && !coopmat2) {
return false;
}
auto fa_kv_ok = [coopmat2](ggml_type t) {
switch (t) {
case GGML_TYPE_F32:

View file

@ -22,6 +22,7 @@
#include "types.glsl"
#include "flash_attn_base.glsl"
#include "flash_attn_dequant.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
@ -128,18 +129,20 @@ void main() {
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
}
#else // Q4_0, Q4_1, Q5_0, Q5_1
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
// Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores
// the row-sum scaled by qd, used in k_dot_correction.
if (FaTypeK == FA_TYPE_Q8_0) {
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
}
} else {
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
}
}
#endif
#endif
}
barrier();
@ -177,13 +180,9 @@ void main() {
// mo_offset will point to the tile starting at row i*Br and col 0
uint32_t mo_offset = mo_stride * i;
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
// FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants.
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
@ -257,21 +256,21 @@ void main() {
if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
if (USE_DECODE_K) {
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
} else {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
}
}
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
#else // MMQ
const uint ints_per_block = 8 / QUANT_R_MMQ;
const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK);
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
const uint32_t iqs = (idx + tid) % ints_per_block;
@ -310,15 +309,13 @@ void main() {
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
} else if (USE_DECODE_K) {
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
} else {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
@ -335,15 +332,13 @@ void main() {
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
} else if (USE_DECODE_K) {
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
} else {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
@ -366,72 +361,47 @@ void main() {
int32_t k_quants[d_per_step];
ACC_TYPEV2 k_dm;
// Q4_*/Q5_* take the block-8 fast path when one step covers a full
// block; Q8_0 always goes through the per-int get_k_qs* helpers
// (its qs is byte-packed, not nibble-packed).
const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0);
if (SHMEM_STAGING != 0) {
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
#else
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
if (block8_fast) {
const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1);
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = kblocksh[buf_ib].qs[d];
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
if (has_qh) {
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
}
}
} else
#endif
{
} else {
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
}
}
} else {
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
const uint ib = coord / BLOCK_SIZE;
const uint iqs = (coord % BLOCK_SIZE);
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
const uint ib = coord / BLOCK_SIZE_K;
const uint iqs = (coord % BLOCK_SIZE_K);
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
#else
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
#if defined(DATA_A_Q5_0)
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
k_packed.k_data_packed16[k_offset + ib].qh[1]));
#elif defined(DATA_A_Q5_1)
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
#endif
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
#if defined(A_TYPE_PACKED32)
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
#else
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
#endif
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (qh >> (d * 4)) & 0xF;
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset));
if (block8_fast) {
fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset);
[[unroll]] for (uint32_t d = 0; d < 8; d++) {
k_quants[d] = blk.qs[d];
}
} else
#endif
{
} else {
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
}
@ -516,14 +486,14 @@ void main() {
if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
if (USE_DECODE_V) {
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
uint ib = coord / BLOCK_SIZE_V;
uint iqs = (coord % BLOCK_SIZE_V);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
} else {
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
}
}
kvsh[c * kvsh_stride + d] = V_Tf;
@ -547,15 +517,13 @@ void main() {
FLOAT_TYPEV4 Vf;
if (SHMEM_STAGING != 0) {
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
} else if (USE_DECODE_V) {
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE_V;
uint iqs = (coord % BLOCK_SIZE_V);
Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
} else {
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);

View file

@ -87,176 +87,58 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
#endif
// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the
// host can pass the type directly. Keep in sync with ggml.h.
#define FA_TYPE_F32 0u
#define FA_TYPE_F16 1u
#define FA_TYPE_Q4_0 2u
#define FA_TYPE_Q4_1 3u
#define FA_TYPE_Q5_0 6u
#define FA_TYPE_Q5_1 7u
#define FA_TYPE_Q8_0 8u
#define FA_TYPE_Q1_0 41u
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif
#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
// iqs is currently always zero in the flash attention shaders
if (binding_idx == BINDING_IDX_K) {
return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
} else {
return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
// Number of matrix elements per buffer block, derived from the K/V type spec
// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1
// and bypasses the dequant path entirely. Quants follow their ggml block sizes.
uint fa_block_elems(uint ty) {
switch (ty) {
case FA_TYPE_F32: return 4u;
case FA_TYPE_F16: return 1u;
case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0);
case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1);
case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0);
case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1);
case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0);
case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere
default: return 1u;
}
}
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
#elif defined(DATA_A_Q4_1)
#define BLOCK_BYTE_SIZE 20
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q4_1
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
#endif
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q4_1
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
#endif
// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte
// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number
// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R.
uint fa_quant_r_mmq(uint ty) {
switch (ty) {
case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0);
case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1);
case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0);
case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1);
case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0);
default: return 1u;
}
}
#endif
#if defined(DATA_A_Q5_0)
#define BLOCK_BYTE_SIZE 22
#elif defined(DATA_A_Q5_1)
#define BLOCK_BYTE_SIZE 24
#endif
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
#ifdef DATA_A_Q5_1
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
#else
uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16);
#endif
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q5_1
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
#endif
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
#ifdef DATA_A_Q5_1
uint qh = v_packed.v_data_packed16[a_offset + ib].qh;
#else
uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16);
#endif
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q5_1
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
#endif
}
}
#endif
#if defined(DATA_A_IQ4_NL)
#define BLOCK_BYTE_SIZE 18
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
kvalues_iq4nl[vui_lo & 0xF],
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
kvalues_iq4nl[vui_hi & 0xF],
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
kvalues_iq4nl[vui_lo & 0xF],
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
kvalues_iq4nl[vui_hi & 0xF],
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
}
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
} else {
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
}
}
#endif
// These can't be `const` globals because GLSL forbids function calls in global
// const initializers, even when the spec constants would let the driver fold
// them. Macros expand at the use site and fold after specialization.
#define BLOCK_SIZE_K fa_block_elems(FaTypeK)
#define BLOCK_SIZE_V fa_block_elems(FaTypeV)
// F16 reads f16 elements directly from the binding; everything else routes
// through dequantize4 / the MMQ helpers to unpack from the packed block layout.
#define USE_DECODE_K (FaTypeK != FA_TYPE_F16)
#define USE_DECODE_V (FaTypeV != FA_TYPE_F16)
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

View file

@ -14,6 +14,7 @@
#include "types.glsl"
#include "flash_attn_base.glsl"
#include "flash_attn_dequant.glsl"
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
@ -127,13 +128,9 @@ void main() {
// mo_offset will point to the tile starting at row i*Br and col 0
uint32_t mo_offset = mo_stride * i;
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
// FaBlockBytesK/V == 2 for f16 (sizeof f16) and == 16 for f32 (vec4) and == ggml block size for quants.
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
@ -227,14 +224,14 @@ void main() {
if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
f16vec4 K_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
if (USE_DECODE_K) {
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
} else {
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
}
}
kvsh[c * kvsh_stride + d] = K_Tf;
@ -256,47 +253,40 @@ void main() {
// staged through a Bc * MatBr size staging buffer.
// If K is not type f16, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK) {
#endif
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
uint32_t col_vec = (idx + tid) % (MatBr / 4);
uint32_t row = (idx + tid) / (MatBr / 4);
if (idx + tid < Bc * MatBr / 4) {
f16vec4 K_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
#endif
// For quants we always need to dequant into kvsh; for f16 we can load
// directly from global memory when alignment / bounds allow it.
const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK;
if (stage_k) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
uint32_t col_vec = (idx + tid) % (MatBr / 4);
uint32_t row = (idx + tid) / (MatBr / 4);
if (idx + tid < Bc * MatBr / 4) {
f16vec4 K_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
if (USE_DECODE_K) {
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
} else {
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
}
}
kvsh[row * kvsh_stride + col_vec] = K_Tf;
}
kvsh[row * kvsh_stride + col_vec] = K_Tf;
}
barrier();
}
barrier();
#if BLOCK_SIZE == 1
}
#endif
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK)
#endif
{
if (stage_k) {
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
#if BLOCK_SIZE == 1
else {
} else {
const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
}
#endif
} else {
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
@ -397,14 +387,14 @@ void main() {
if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
f16vec4 V_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
if (USE_DECODE_V) {
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
uint ib = coord / BLOCK_SIZE_V;
uint iqs = (coord % BLOCK_SIZE_V);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
} else {
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
}
}
kvsh[c * kvsh_stride + d] = V_Tf;
@ -431,36 +421,33 @@ void main() {
// staged through a Bc * MatBr size staging buffer.
// If V is not type f16, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
// For f16, only preload if not aligned
if (KV_bounds_check) {
#endif
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
const uint idx = i * gl_WorkGroupSize.x + tid;
const uint row = idx / v_cols;
const uint col = idx % v_cols;
// For quants we always preload via kvsh. For f16 we only preload when
// alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4).
const bool stage_v = USE_DECODE_V || KV_bounds_check;
if (stage_v) {
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
const uint idx = i * gl_WorkGroupSize.x + tid;
const uint row = idx / v_cols;
const uint col = idx % v_cols;
const uint v_row = j * Bc + row;
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
const uint v_row = j * Bc + row;
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
const uint ib = coord / BLOCK_SIZE;
const uint iqs = coord % BLOCK_SIZE;
const uint coord = v_row * v_stride * BLOCK_SIZE_V + v_col;
const uint ib = coord / BLOCK_SIZE_V;
const uint iqs = coord % BLOCK_SIZE_V;
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if BLOCK_SIZE > 1
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
#endif
} else {
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
if (USE_DECODE_V) {
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
} else {
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
}
} else {
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
}
}
}
#if BLOCK_SIZE == 1
}
#endif
}
barrier();
@ -471,15 +458,12 @@ void main() {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
if (!USE_DECODE_V && !KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
} else {
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}

View file

@ -28,43 +28,28 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_
uint8_t raw[FaBlockBytesV];
};
uint fa_block_elems(uint ty) {
switch (ty) {
case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32)
case 1u: return 1u; // GGML_TYPE_F16
case 2u: return uint(QUANT_K_Q4_0);
case 3u: return uint(QUANT_K_Q4_1);
case 6u: return uint(QUANT_K_Q5_0);
case 7u: return uint(QUANT_K_Q5_1);
case 8u: return uint(QUANT_K_Q8_0);
case 41u: return uint(QUANT_K_Q1_0);
default:
return 1u;
}
}
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
switch (FaTypeK) {
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
default: return float16_t(0);
}
}
float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
switch (FaTypeV) {
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
default: return float16_t(0);
}
}

View file

@ -0,0 +1,123 @@
// Asymmetric K/V flash attention: aliased SSBO views of bindings 1 (K) and 2 (V)
// covering every supported FA element type, plus an uber dequantize4() that
// switches on FaTypeK / FaTypeV. After spec-constant specialization the driver
// folds away every path except the one matching the K/V type for this pipeline.
//
// Included by flash_attn.comp and flash_attn_cm1.comp. Not included by
// flash_attn_cm2.comp, which has its own buffer_reference-based decode path.
//
// We use macros (rather than per-quant decode functions taking a struct) on
// purpose: the FA shaders don't enable GL_EXT_shader_explicit_arithmetic_types_float16
// when FLOAT16 isn't defined, which makes float16-containing struct values
// illegal to return from / pass to functions. Macros expand inline where the
// float16 stays in storage and is converted to FLOAT_TYPE at use.
// F32 is fed as a vec4 "block" (4 floats), matching what dequant_funcs_cm2.glsl
// does for F32 in the cm2 shader. FaBlockBytesK/V == 16 for F32.
layout (binding = 1) readonly buffer K_PACKED_F32 { vec4 data[]; } k_packed_f32;
layout (binding = 2) readonly buffer V_PACKED_F32 { vec4 data[]; } v_packed_f32;
layout (binding = 1) readonly buffer K_PACKED_Q4_0 { block_q4_0_packed16 data[]; } k_packed_q4_0;
layout (binding = 2) readonly buffer V_PACKED_Q4_0 { block_q4_0_packed16 data[]; } v_packed_q4_0;
layout (binding = 1) readonly buffer K_PACKED_Q4_1 { block_q4_1_packed16 data[]; } k_packed_q4_1;
layout (binding = 2) readonly buffer V_PACKED_Q4_1 { block_q4_1_packed16 data[]; } v_packed_q4_1;
layout (binding = 1) readonly buffer K_PACKED_Q5_0 { block_q5_0_packed16 data[]; } k_packed_q5_0;
layout (binding = 2) readonly buffer V_PACKED_Q5_0 { block_q5_0_packed16 data[]; } v_packed_q5_0;
layout (binding = 1) readonly buffer K_PACKED_Q5_1 { block_q5_1_packed16 data[]; } k_packed_q5_1;
layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; } v_packed_q5_1;
layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0;
layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0;
// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16
// views, used by the MMQ K-side hot path for fast 4-uint loads.
layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32;
layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 data[]; } k_packed_q5_1_p32;
// Per-quant decode bodies are expanded once for the K view set and once for
// the V view set. The macros take the buffer name as a parameter.
#define FA_DEQUANT4_F32(BUF) \
return FLOAT_TYPEV4(BUF.data[a_offset + ib]);
#define FA_DEQUANT4_Q4_0(BUF) { \
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
uint shift = (iqs & 0x10) >> 2; \
vui_lo >>= shift; \
vui_hi >>= shift; \
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); \
}
#define FA_DEQUANT4_Q4_1(BUF) { \
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
uint shift = (iqs & 0x10) >> 2; \
vui_lo >>= shift; \
vui_hi >>= shift; \
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * nibbles \
+ FLOAT_TYPE(BUF.data[a_offset + ib].m); \
}
#define FA_DEQUANT4_Q5_0(BUF) { \
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
uint shift = (iqs & 0x10) >> 2; \
vui_lo >>= shift; \
vui_hi >>= shift; \
uint qh = uint(BUF.data[a_offset + ib].qh[0]) \
| (uint(BUF.data[a_offset + ib].qh[1]) << 16); \
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \
(qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \
* FLOAT_TYPE(16.0f); \
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); \
}
#define FA_DEQUANT4_Q5_1(BUF) { \
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
uint shift = (iqs & 0x10) >> 2; \
vui_lo >>= shift; \
vui_hi >>= shift; \
uint qh = BUF.data[a_offset + ib].qh; \
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \
(qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \
* FLOAT_TYPE(16.0f); \
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb) \
+ FLOAT_TYPE(BUF.data[a_offset + ib].m); \
}
#define FA_DEQUANT4_Q8_0(BUF) { \
const i8vec2 v0 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 ])).xy; \
const i8vec2 v1 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 + 1])).xy; \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \
}
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
switch (FaTypeK) {
case FA_TYPE_F32: FA_DEQUANT4_F32 (k_packed_f32)
case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(k_packed_q4_0)
case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(k_packed_q4_1)
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0)
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1)
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0)
}
} else {
switch (FaTypeV) {
case FA_TYPE_F32: FA_DEQUANT4_F32 (v_packed_f32)
case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(v_packed_q4_0)
case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(v_packed_q4_1)
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0)
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1)
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0)
}
}
return FLOAT_TYPEV4(0);
}

View file

@ -1,149 +1,203 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// MMQ K-side helpers, asymmetric form. Each function dispatches on FaTypeK and
// reads from the matching aliased K binding declared in flash_attn_dequant.glsl.
// Spec-constant specialization folds the unused paths.
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q4_0
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
#else
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
#endif
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
return int32_t(vui & 0x0F0F0F0F);
switch (FaTypeK) {
case FA_TYPE_Q4_0: {
uint vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
return int32_t(vui & 0x0F0F0F0F);
}
case FA_TYPE_Q4_1: { // uses packed32 alias
uint vui = k_packed_q4_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4];
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
return int32_t(vui & 0x0F0F0F0F);
}
case FA_TYPE_Q5_0: {
uint vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0],
k_packed_q5_0.data[a_offset + ib].qh[1]));
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
uint qh_bits = (qh >> iqs) & 0xF;
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
}
case FA_TYPE_Q5_1: { // qs via packed32, qh via packed16
uint vui = k_packed_q5_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4];
uint qh = k_packed_q5_1.data[a_offset + ib].qh;
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
uint qh_bits = (qh >> iqs) & 0xF;
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
}
case FA_TYPE_Q8_0: {
return pack32(i16vec2(k_packed_q8_0.data[a_offset + ib].qs[iqs / 2],
k_packed_q8_0.data[a_offset + ib].qs[iqs / 2 + 1]));
}
default: return 0;
}
}
#endif
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q5_0
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
k_packed.k_data_packed16[a_offset + ib].qh[1]));
#else
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
#endif
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
uint qh_bits = (qh >> iqs) & 0xF;
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
// Per-block scale/min, packed as (d, m). Single-scale types (Q4_0, Q5_0, Q8_0)
// return (d, 0) so call sites always see the same shape.
FLOAT_TYPEV2 get_k_scale(uint ib, uint a_offset) {
switch (FaTypeK) {
case FA_TYPE_Q4_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + ib].d), 0.0);
case FA_TYPE_Q4_1: return FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + ib].dm);
case FA_TYPE_Q5_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + ib].d), 0.0);
case FA_TYPE_Q5_1: return FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + ib].dm);
case FA_TYPE_Q8_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + ib].d), 0.0);
default: return FLOAT_TYPEV2(0);
}
}
#endif
#if defined(DATA_A_Q8_0)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
}
#endif
#if defined(DATA_A_IQ4_NL)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
kvalues_iq4nl_const[idx.y],
kvalues_iq4nl_const[idx.z],
kvalues_iq4nl_const[idx.w]));
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
}
#else
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
}
#endif
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
#if defined(DATA_A_Q4_0)
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_Q4_1)
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
#elif defined(DATA_A_Q5_0)
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
// kblocksh[].qs is int32_t for the unified MMQ struct; uint sources need
// explicit casts. The bit pattern is what we care about here -- the actual
// signed/unsigned interpretation happens downstream in the dot product.
switch (FaTypeK) {
case FA_TYPE_Q4_0: {
kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2],
k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2 + 1])));
break;
}
case FA_TYPE_Q4_1: {
kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q4_1_p32.data[a_offset + global_ib].qs[iqs]);
break;
}
case FA_TYPE_Q5_0: {
kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2],
k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2 + 1])));
if (iqs == 0) {
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qh[0],
k_packed_q5_0.data[a_offset + global_ib].qh[1]));
}
break;
}
case FA_TYPE_Q5_1: {
kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q5_1_p32.data[a_offset + global_ib].qs[iqs]);
if (iqs == 0) {
kblocksh[buf_ib].qh = k_packed_q5_1.data[a_offset + global_ib].qh;
}
break;
}
case FA_TYPE_Q8_0: {
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2],
k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2 + 1]));
break;
}
}
#elif defined(DATA_A_Q5_1)
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
if (iqs == 0) {
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
}
#elif defined(DATA_A_Q8_0)
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_IQ4_NL)
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
#endif
if (iqs == 0) {
#if QUANT_AUXF == 1
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
#else
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
#endif
// Q4_0/Q5_0/Q8_0 store dm.x = d; Q4_1/Q5_1 store dm = (d, m) pair.
switch (FaTypeK) {
case FA_TYPE_Q4_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + global_ib].d), 0.0); break;
case FA_TYPE_Q4_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + global_ib].dm); break;
case FA_TYPE_Q5_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + global_ib].d), 0.0); break;
case FA_TYPE_Q5_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + global_ib].dm); break;
case FA_TYPE_Q8_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + global_ib].d), 0.0); break;
}
}
}
// d_per_step==8 hot path: read one full 32-element block worth of nibble-packed
// int32 quants. Equivalent to 8 calls to get_k_qs(ib, d*4, a_offset) but reads
// qh (Q5_*) and runs pack32 (Q4_0/Q5_0) once per block instead of per nibble
// quad. iqs is always 0 in this path (hsk4 % 8 == 0 implies block-aligned).
// Q8_0 takes the generic get_k_qs path because its qs layout (i8 pairs) doesn't
// share this nibble shape.
//
// Returned via a struct so the caller's k_quants array (sized from spec
// constants) doesn't need to match a fixed[8] out-parameter type.
struct fa_k_qs_block8 {
int32_t qs[8];
};
fa_k_qs_block8 get_k_qs_block8(uint ib, uint a_offset) {
fa_k_qs_block8 r;
uint qh = 0;
if (FaTypeK == FA_TYPE_Q5_0) {
qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0],
k_packed_q5_0.data[a_offset + ib].qh[1]));
} else if (FaTypeK == FA_TYPE_Q5_1) {
qh = k_packed_q5_1.data[a_offset + ib].qh;
}
const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1);
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = 0;
switch (FaTypeK) {
case FA_TYPE_Q4_0: { // packed16
vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 0],
k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 1]));
break;
}
case FA_TYPE_Q4_1: { // packed32 alias
vui = k_packed_q4_1_p32.data[a_offset + ib].qs[d];
break;
}
case FA_TYPE_Q5_0: { // packed16
vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 0],
k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 1]));
break;
}
case FA_TYPE_Q5_1: { // packed32 alias
vui = k_packed_q5_1_p32.data[a_offset + ib].qs[d];
break;
}
}
r.qs[d ] = int32_t( vui & 0x0F0F0F0F);
r.qs[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
if (has_qh) {
uint qh_lo = (qh >> (d * 4)) & 0xFu;
uint qh_hi = (qh >> (d * 4 + 16)) & 0xFu;
r.qs[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
r.qs[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
}
}
return r;
}
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4 : 0;
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4 : 0;
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
return kblocksh[buf_ib].qs[pos];
#endif
switch (FaTypeK) {
case FA_TYPE_Q4_0:
case FA_TYPE_Q4_1: {
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4u : 0u;
return int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu);
}
case FA_TYPE_Q5_0:
case FA_TYPE_Q5_1: {
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4u : 0u;
int32_t result = int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu);
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4u)) & 0xFu;
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
}
case FA_TYPE_Q8_0: {
return kblocksh[buf_ib].qs[pos];
}
default: return 0;
}
}
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
#if defined(DATA_A_Q4_0)
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q5_0)
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
#else
return ACC_TYPE(0.0);
#endif
switch (FaTypeK) {
case FA_TYPE_Q4_0: return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
case FA_TYPE_Q5_0: return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
case FA_TYPE_Q4_1:
case FA_TYPE_Q5_1: return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
default: return ACC_TYPE(0.0);
}
}
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
kblocksh[buf_ib].qs[iqs] = 0;
#if defined(DATA_A_IQ4_NL)
kblocksh[buf_ib].qs[iqs + 4] = 0;
#endif
if (iqs == 0) {
#if QUANT_AUXF == 1
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
#else
kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
#endif
}
}

View file

@ -1,4 +1,13 @@
#if defined(DATA_A_Q4_0)
#if defined(FA_MMQ_MIXED)
// Mixed-K flash attention MMQ: superset cache that fits Q4_0/Q4_1/Q5_0/Q5_1/Q8_0.
// Q4_*/Q5_* only use qs[0..3] and (for Q5_*) qh. Q8_0 uses qs[0..7]. Single-scale
// types (Q4_0/Q5_0/Q8_0) leave dm.y unused.
struct block_a_cache {
int32_t qs[8];
uint32_t qh;
FLOAT_TYPEV2 dm;
};
#elif defined(DATA_A_Q4_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];

View file

@ -643,42 +643,22 @@ void process_shaders() {
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp",
string_to_spv("flash_attn_f32_f16", "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
#endif
}
for (const auto& tname : type_names) {
if (tname == "bf16") continue;
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
}
string_to_spv("flash_attn_f32_f16", "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
#endif
}
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (tname != "f32") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
}
#endif
}
}
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8");
#endif
}
}