mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-19 16:31:59 +00:00
vulkan: Support asymmetric FA in scalar/mmq/coopmat1 paths (#22589)
This commit is contained in:
parent
8cef8201a1
commit
dd9280a664
9 changed files with 632 additions and 680 deletions
|
|
@ -855,7 +855,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
|
||||
|
||||
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
|
||||
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16;
|
||||
|
||||
std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
|
||||
|
||||
|
|
@ -2933,10 +2933,10 @@ struct vk_fa_tuning_params {
|
|||
}
|
||||
};
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type);
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
result.path = FA_SCALAR;
|
||||
|
|
@ -2988,7 +2988,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
|||
|
||||
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
||||
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) {
|
||||
result.block_rows /= 2;
|
||||
}
|
||||
|
||||
|
|
@ -3011,10 +3011,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
|||
return result;
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
GGML_UNUSED(n_rows);
|
||||
GGML_UNUSED(n_kv);
|
||||
GGML_UNUSED(kv_type);
|
||||
GGML_UNUSED(k_type);
|
||||
GGML_UNUSED(v_type);
|
||||
GGML_UNUSED(f32acc);
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
|
|
@ -3070,12 +3071,6 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
|
|||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
// Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1.
|
||||
if (k_type != v_type) {
|
||||
GGML_ASSERT(device->coopmat2);
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
}
|
||||
|
||||
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
|
||||
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||
|
||||
|
|
@ -3087,7 +3082,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
|||
if (path == FA_COOPMAT1) {
|
||||
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
|
||||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
|
||||
|
||||
if (!shape_ok || !shmem_ok) {
|
||||
|
|
@ -3107,9 +3102,9 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
|||
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
case FA_COOPMAT1:
|
||||
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
case FA_COOPMAT2:
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
default:
|
||||
|
|
@ -3279,6 +3274,20 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
|
|||
return 0; // If no matching configuration is found
|
||||
}
|
||||
|
||||
// Whether scalar flash attention will use the MMQ path for the given k_type.
|
||||
static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) {
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
return device->integer_dot_product && device->subgroup_clustered &&
|
||||
(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 ||
|
||||
k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 ||
|
||||
k_type == GGML_TYPE_Q8_0);
|
||||
#else
|
||||
GGML_UNUSED(device);
|
||||
GGML_UNUSED(k_type);
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vk_load_shaders(vk_device& device) {
|
||||
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
||||
|
||||
|
|
@ -3525,121 +3534,96 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
align, disable_robustness, require_full_subgroups, required_subgroup_size);
|
||||
};
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
|
||||
FaCodePath path = fa.first.path; \
|
||||
uint32_t Br = fa.first.Br; \
|
||||
uint32_t Bc = fa.first.Bc; \
|
||||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
uint32_t fa_sgs = fa.first.subgroup_size; \
|
||||
bool fa_ds = fa.first.subgroup_size == 0; \
|
||||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
// FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V
|
||||
// quant type is selected at runtime via the FaTypeK / FaTypeV spec constants.
|
||||
|
||||
if (device->fp16) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
|
||||
if (fa.first.path != FA_SCALAR) continue;
|
||||
const uint32_t Br = fa.first.Br;
|
||||
const uint32_t Bc = fa.first.Bc;
|
||||
const bool aligned = fa.first.aligned;
|
||||
const bool f32acc = fa.first.f32acc;
|
||||
const uint32_t fa_sgs = fa.first.subgroup_size;
|
||||
const bool fa_ds = fa.first.subgroup_size == 0;
|
||||
|
||||
const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type);
|
||||
const void * spv_data = nullptr;
|
||||
size_t spv_size = 0;
|
||||
if (use_mmq) {
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
|
||||
} else
|
||||
if (device->fp16) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; }
|
||||
} else {
|
||||
spv_data = flash_attn_f32_f16_fp32_int8_data;
|
||||
spv_size = flash_attn_f32_f16_fp32_int8_len;
|
||||
}
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
}
|
||||
} else {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
} else {
|
||||
if (device->fp16) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
|
||||
} else {
|
||||
spv_data = flash_attn_f32_f16_fp32_data;
|
||||
spv_size = flash_attn_f32_f16_fp32_len;
|
||||
}
|
||||
}
|
||||
const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
|
||||
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
||||
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
||||
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
|
||||
!fa_ds, !fa_ds ? fa_sgs : 0);
|
||||
}
|
||||
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1)
|
||||
}
|
||||
#endif
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
#define CREATE_FA_CM2_MIXED() \
|
||||
for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \
|
||||
FaCodePath path = fa.first.path; \
|
||||
uint32_t Br = fa.first.Br; \
|
||||
uint32_t Bc = fa.first.Bc; \
|
||||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
if (path == FA_COOPMAT2) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
|
||||
if (fa.first.path != FA_COOPMAT1) continue;
|
||||
const uint32_t Br = fa.first.Br;
|
||||
const uint32_t Bc = fa.first.Bc;
|
||||
const bool aligned = fa.first.aligned;
|
||||
const bool f32acc = fa.first.f32acc;
|
||||
const uint32_t fa_sgs = fa.first.subgroup_size;
|
||||
const bool fa_ds = fa.first.subgroup_size == 0;
|
||||
|
||||
const void * spv_data;
|
||||
size_t spv_size;
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
|
||||
const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
|
||||
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
||||
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
||||
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
|
||||
!fa_ds, !fa_ds ? fa_sgs : 0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
|
||||
if (fa.first.path != FA_COOPMAT2) continue;
|
||||
const uint32_t Br = fa.first.Br;
|
||||
const uint32_t Bc = fa.first.Bc;
|
||||
const bool aligned = fa.first.aligned;
|
||||
const bool f32acc = fa.first.f32acc;
|
||||
|
||||
const void * spv_data;
|
||||
size_t spv_size;
|
||||
const char * name;
|
||||
if (aligned) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; }
|
||||
} else {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; }
|
||||
}
|
||||
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
||||
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
||||
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0);
|
||||
}
|
||||
if (device->coopmat2) {
|
||||
CREATE_FA_CM2_MIXED();
|
||||
}
|
||||
#undef CREATE_FA_CM2_MIXED
|
||||
#endif
|
||||
#undef CREATE_FA
|
||||
|
||||
const int mul_mat_id_param_count = 5;
|
||||
|
||||
|
|
@ -8940,8 +8924,9 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) {
|
||||
GGML_UNUSED(f32acc);
|
||||
GGML_UNUSED(v_type);
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t wg_size = params.workgroup_size;
|
||||
const uint32_t Br = params.block_rows;
|
||||
|
|
@ -8949,10 +8934,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
|||
|
||||
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
|
||||
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
|
||||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
|
||||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
|
||||
const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type);
|
||||
|
||||
// tmpsh is overestimated slightly
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
|
|
@ -8969,17 +8951,10 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
|||
// kvsh uses D = HSV (K goes through kblocksh instead)
|
||||
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
// block_a_cache size depends on quant type
|
||||
uint32_t block_a_size;
|
||||
switch (kv_type) {
|
||||
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
|
||||
default: block_a_size = 0; break;
|
||||
}
|
||||
// The mixed MMQ shader uses a superset block_a_cache that fits every
|
||||
// FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm.
|
||||
// Single-scale types leave dm.y unused; non-Q5_* leave qh unused.
|
||||
const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size;
|
||||
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
|
||||
} else {
|
||||
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
|
|
@ -9117,10 +9092,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
|
||||
|
||||
if (tuning_params.path != FA_COOPMAT2) {
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
}
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||
|
|
@ -9164,7 +9135,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
{
|
||||
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
|
||||
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
|
||||
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
|
||||
auto it = pipelines.find(fa_pipeline_state);
|
||||
if (it != pipelines.end()) {
|
||||
pipeline = it->second;
|
||||
|
|
@ -15642,10 +15613,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
// mismatching K/V type is currently supported for coopmat2 only.
|
||||
if (op->src[1]->type != op->src[2]->type && !coopmat2) {
|
||||
return false;
|
||||
}
|
||||
auto fa_kv_ok = [coopmat2](ggml_type t) {
|
||||
switch (t) {
|
||||
case GGML_TYPE_F32:
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
#include "flash_attn_dequant.glsl"
|
||||
|
||||
const uint32_t HSK_per_thread = HSK / D_split;
|
||||
const uint32_t HSV_per_thread = HSV / D_split;
|
||||
|
|
@ -128,18 +129,20 @@ void main() {
|
|||
|
||||
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
|
||||
|
||||
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
|
||||
}
|
||||
#else // Q4_0, Q4_1, Q5_0, Q5_1
|
||||
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
|
||||
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
|
||||
// Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores
|
||||
// the row-sum scaled by qd, used in k_dot_correction.
|
||||
if (FaTypeK == FA_TYPE_Q8_0) {
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
|
||||
}
|
||||
} else {
|
||||
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
|
||||
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
|
||||
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
barrier();
|
||||
|
|
@ -177,13 +180,9 @@ void main() {
|
|||
// mo_offset will point to the tile starting at row i*Br and col 0
|
||||
uint32_t mo_offset = mo_stride * i;
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
||||
#else
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
// FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants.
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
|
||||
uint32_t m_offset = gqa_iq1*KV;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
|
|
@ -257,21 +256,21 @@ void main() {
|
|||
if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
|
||||
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
} else {
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint ints_per_block = 8 / QUANT_R_MMQ;
|
||||
const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK);
|
||||
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
|
||||
const uint32_t iqs = (idx + tid) % ints_per_block;
|
||||
|
|
@ -310,15 +309,13 @@ void main() {
|
|||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
} else if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
} else {
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
|
||||
|
|
@ -335,15 +332,13 @@ void main() {
|
|||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
} else if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
} else {
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
|
||||
|
|
@ -366,72 +361,47 @@ void main() {
|
|||
int32_t k_quants[d_per_step];
|
||||
ACC_TYPEV2 k_dm;
|
||||
|
||||
// Q4_*/Q5_* take the block-8 fast path when one step covers a full
|
||||
// block; Q8_0 always goes through the per-int get_k_qs* helpers
|
||||
// (its qs is byte-packed, not nibble-packed).
|
||||
const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0);
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
if (block8_fast) {
|
||||
const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1);
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
uint vui = kblocksh[buf_ib].qs[d];
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
if (has_qh) {
|
||||
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
} else {
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
|
||||
const uint ib = coord / BLOCK_SIZE;
|
||||
const uint iqs = (coord % BLOCK_SIZE);
|
||||
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
|
||||
const uint ib = coord / BLOCK_SIZE_K;
|
||||
const uint iqs = (coord % BLOCK_SIZE_K);
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
|
||||
#endif
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
#if defined(DATA_A_Q5_0)
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qh[1]));
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
|
||||
#endif
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
|
||||
#else
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
|
||||
#endif
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset));
|
||||
|
||||
if (block8_fast) {
|
||||
fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset);
|
||||
[[unroll]] for (uint32_t d = 0; d < 8; d++) {
|
||||
k_quants[d] = blk.qs[d];
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
} else {
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
|
||||
}
|
||||
|
|
@ -516,14 +486,14 @@ void main() {
|
|||
if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
|
||||
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
#endif
|
||||
if (USE_DECODE_V) {
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_V;
|
||||
uint iqs = (coord % BLOCK_SIZE_V);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
} else {
|
||||
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = V_Tf;
|
||||
|
|
@ -547,15 +517,13 @@ void main() {
|
|||
FLOAT_TYPEV4 Vf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
} else if (USE_DECODE_V) {
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE_V;
|
||||
uint iqs = (coord % BLOCK_SIZE_V);
|
||||
Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
} else {
|
||||
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);
|
||||
|
|
|
|||
|
|
@ -87,176 +87,58 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
|
|||
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
|
||||
#elif defined(A_TYPE_PACKED16)
|
||||
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
|
||||
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
|
||||
#endif
|
||||
// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the
|
||||
// host can pass the type directly. Keep in sync with ggml.h.
|
||||
#define FA_TYPE_F32 0u
|
||||
#define FA_TYPE_F16 1u
|
||||
#define FA_TYPE_Q4_0 2u
|
||||
#define FA_TYPE_Q4_1 3u
|
||||
#define FA_TYPE_Q5_0 6u
|
||||
#define FA_TYPE_Q5_1 7u
|
||||
#define FA_TYPE_Q8_0 8u
|
||||
#define FA_TYPE_Q1_0 41u
|
||||
|
||||
#ifndef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 1
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
#undef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 4
|
||||
#define BLOCK_BYTE_SIZE 16
|
||||
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
// iqs is currently always zero in the flash attention shaders
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
|
||||
} else {
|
||||
return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
|
||||
// Number of matrix elements per buffer block, derived from the K/V type spec
|
||||
// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1
|
||||
// and bypasses the dequant path entirely. Quants follow their ggml block sizes.
|
||||
uint fa_block_elems(uint ty) {
|
||||
switch (ty) {
|
||||
case FA_TYPE_F32: return 4u;
|
||||
case FA_TYPE_F16: return 1u;
|
||||
case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0);
|
||||
case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1);
|
||||
case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0);
|
||||
case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1);
|
||||
case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0);
|
||||
case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere
|
||||
default: return 1u;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
#define BLOCK_BYTE_SIZE 20
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
|
||||
#ifdef DATA_A_Q4_1
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
|
||||
#else
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
|
||||
#endif
|
||||
} else {
|
||||
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
|
||||
#ifdef DATA_A_Q4_1
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
|
||||
#else
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
|
||||
#endif
|
||||
// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte
|
||||
// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number
|
||||
// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R.
|
||||
uint fa_quant_r_mmq(uint ty) {
|
||||
switch (ty) {
|
||||
case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0);
|
||||
case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1);
|
||||
case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0);
|
||||
case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1);
|
||||
case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0);
|
||||
default: return 1u;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0)
|
||||
#define BLOCK_BYTE_SIZE 22
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
#define BLOCK_BYTE_SIZE 24
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
#ifdef DATA_A_Q5_1
|
||||
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
|
||||
#else
|
||||
uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16);
|
||||
#endif
|
||||
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
|
||||
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
|
||||
#ifdef DATA_A_Q5_1
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
|
||||
#else
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
|
||||
#endif
|
||||
} else {
|
||||
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
#ifdef DATA_A_Q5_1
|
||||
uint qh = v_packed.v_data_packed16[a_offset + ib].qh;
|
||||
#else
|
||||
uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16);
|
||||
#endif
|
||||
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
|
||||
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
|
||||
#ifdef DATA_A_Q5_1
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
|
||||
#else
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
|
||||
kvalues_iq4nl[vui_lo & 0xF],
|
||||
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
|
||||
kvalues_iq4nl[vui_hi & 0xF],
|
||||
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
|
||||
} else {
|
||||
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
|
||||
kvalues_iq4nl[vui_lo & 0xF],
|
||||
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
|
||||
kvalues_iq4nl[vui_hi & 0xF],
|
||||
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
|
||||
} else {
|
||||
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// These can't be `const` globals because GLSL forbids function calls in global
|
||||
// const initializers, even when the spec constants would let the driver fold
|
||||
// them. Macros expand at the use site and fold after specialization.
|
||||
#define BLOCK_SIZE_K fa_block_elems(FaTypeK)
|
||||
#define BLOCK_SIZE_V fa_block_elems(FaTypeV)
|
||||
// F16 reads f16 elements directly from the binding; everything else routes
|
||||
// through dequantize4 / the MMQ helpers to unpack from the packed block layout.
|
||||
#define USE_DECODE_K (FaTypeK != FA_TYPE_F16)
|
||||
#define USE_DECODE_V (FaTypeV != FA_TYPE_F16)
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
#include "flash_attn_dequant.glsl"
|
||||
|
||||
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
|
||||
const uint32_t MatBr = 16;
|
||||
|
|
@ -127,13 +128,9 @@ void main() {
|
|||
// mo_offset will point to the tile starting at row i*Br and col 0
|
||||
uint32_t mo_offset = mo_stride * i;
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
||||
#else
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
// FaBlockBytesK/V == 2 for f16 (sizeof f16) and == 16 for f32 (vec4) and == ggml block size for quants.
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
|
||||
uint32_t m_offset = gqa_iq1*KV;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
|
|
@ -227,14 +224,14 @@ void main() {
|
|||
if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
} else {
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
|
|
@ -256,47 +253,40 @@ void main() {
|
|||
// staged through a Bc * MatBr size staging buffer.
|
||||
// If K is not type f16, then it is always staged for dequantization.
|
||||
if (SHMEM_STAGING == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
if (KV_bounds_check || d * 16 + 16 > HSK) {
|
||||
#endif
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t col_vec = (idx + tid) % (MatBr / 4);
|
||||
uint32_t row = (idx + tid) / (MatBr / 4);
|
||||
if (idx + tid < Bc * MatBr / 4) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
|
||||
#endif
|
||||
// For quants we always need to dequant into kvsh; for f16 we can load
|
||||
// directly from global memory when alignment / bounds allow it.
|
||||
const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK;
|
||||
if (stage_k) {
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t col_vec = (idx + tid) % (MatBr / 4);
|
||||
uint32_t row = (idx + tid) / (MatBr / 4);
|
||||
if (idx + tid < Bc * MatBr / 4) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
|
||||
if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4;
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
} else {
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[row * kvsh_stride + col_vec] = K_Tf;
|
||||
}
|
||||
|
||||
kvsh[row * kvsh_stride + col_vec] = K_Tf;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
barrier();
|
||||
#if BLOCK_SIZE == 1
|
||||
}
|
||||
#endif
|
||||
|
||||
#if BLOCK_SIZE == 1
|
||||
if (KV_bounds_check || d * 16 + 16 > HSK)
|
||||
#endif
|
||||
{
|
||||
if (stage_k) {
|
||||
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
|
||||
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
#if BLOCK_SIZE == 1
|
||||
else {
|
||||
} else {
|
||||
const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
|
||||
coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
|
||||
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
|
@ -397,14 +387,14 @@ void main() {
|
|||
if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
|
||||
f16vec4 V_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
#endif
|
||||
if (USE_DECODE_V) {
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_V;
|
||||
uint iqs = (coord % BLOCK_SIZE_V);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
} else {
|
||||
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = V_Tf;
|
||||
|
|
@ -431,36 +421,33 @@ void main() {
|
|||
// staged through a Bc * MatBr size staging buffer.
|
||||
// If V is not type f16, then it is always staged for dequantization.
|
||||
if (SHMEM_STAGING == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
// For f16, only preload if not aligned
|
||||
if (KV_bounds_check) {
|
||||
#endif
|
||||
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
|
||||
const uint idx = i * gl_WorkGroupSize.x + tid;
|
||||
const uint row = idx / v_cols;
|
||||
const uint col = idx % v_cols;
|
||||
// For quants we always preload via kvsh. For f16 we only preload when
|
||||
// alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4).
|
||||
const bool stage_v = USE_DECODE_V || KV_bounds_check;
|
||||
if (stage_v) {
|
||||
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
|
||||
const uint idx = i * gl_WorkGroupSize.x + tid;
|
||||
const uint row = idx / v_cols;
|
||||
const uint col = idx % v_cols;
|
||||
|
||||
const uint v_row = j * Bc + row;
|
||||
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
|
||||
const uint v_row = j * Bc + row;
|
||||
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
|
||||
|
||||
const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
|
||||
const uint ib = coord / BLOCK_SIZE;
|
||||
const uint iqs = coord % BLOCK_SIZE;
|
||||
const uint coord = v_row * v_stride * BLOCK_SIZE_V + v_col;
|
||||
const uint ib = coord / BLOCK_SIZE_V;
|
||||
const uint iqs = coord % BLOCK_SIZE_V;
|
||||
|
||||
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
|
||||
#endif
|
||||
} else {
|
||||
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
|
||||
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
|
||||
if (USE_DECODE_V) {
|
||||
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
} else {
|
||||
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
|
||||
}
|
||||
} else {
|
||||
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if BLOCK_SIZE == 1
|
||||
}
|
||||
#endif
|
||||
}
|
||||
barrier();
|
||||
|
||||
|
|
@ -471,15 +458,12 @@ void main() {
|
|||
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
if (SHMEM_STAGING == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
if (!KV_bounds_check) {
|
||||
if (!USE_DECODE_V && !KV_bounds_check) {
|
||||
// F16 values can be loaded directly from global memory
|
||||
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
|
||||
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
|
||||
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
} else {
|
||||
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
|
||||
coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,43 +28,28 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_
|
|||
uint8_t raw[FaBlockBytesV];
|
||||
};
|
||||
|
||||
uint fa_block_elems(uint ty) {
|
||||
switch (ty) {
|
||||
case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32)
|
||||
case 1u: return 1u; // GGML_TYPE_F16
|
||||
case 2u: return uint(QUANT_K_Q4_0);
|
||||
case 3u: return uint(QUANT_K_Q4_1);
|
||||
case 6u: return uint(QUANT_K_Q5_0);
|
||||
case 7u: return uint(QUANT_K_Q5_1);
|
||||
case 8u: return uint(QUANT_K_Q8_0);
|
||||
case 41u: return uint(QUANT_K_Q1_0);
|
||||
default:
|
||||
return 1u;
|
||||
}
|
||||
}
|
||||
|
||||
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
|
||||
switch (FaTypeK) {
|
||||
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
|
||||
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
default: return float16_t(0);
|
||||
}
|
||||
}
|
||||
|
||||
float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
|
||||
switch (FaTypeV) {
|
||||
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
|
||||
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
default: return float16_t(0);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
123
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl
Normal file
123
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
// Asymmetric K/V flash attention: aliased SSBO views of bindings 1 (K) and 2 (V)
|
||||
// covering every supported FA element type, plus an uber dequantize4() that
|
||||
// switches on FaTypeK / FaTypeV. After spec-constant specialization the driver
|
||||
// folds away every path except the one matching the K/V type for this pipeline.
|
||||
//
|
||||
// Included by flash_attn.comp and flash_attn_cm1.comp. Not included by
|
||||
// flash_attn_cm2.comp, which has its own buffer_reference-based decode path.
|
||||
//
|
||||
// We use macros (rather than per-quant decode functions taking a struct) on
|
||||
// purpose: the FA shaders don't enable GL_EXT_shader_explicit_arithmetic_types_float16
|
||||
// when FLOAT16 isn't defined, which makes float16-containing struct values
|
||||
// illegal to return from / pass to functions. Macros expand inline where the
|
||||
// float16 stays in storage and is converted to FLOAT_TYPE at use.
|
||||
|
||||
// F32 is fed as a vec4 "block" (4 floats), matching what dequant_funcs_cm2.glsl
|
||||
// does for F32 in the cm2 shader. FaBlockBytesK/V == 16 for F32.
|
||||
layout (binding = 1) readonly buffer K_PACKED_F32 { vec4 data[]; } k_packed_f32;
|
||||
layout (binding = 2) readonly buffer V_PACKED_F32 { vec4 data[]; } v_packed_f32;
|
||||
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q4_0 { block_q4_0_packed16 data[]; } k_packed_q4_0;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q4_0 { block_q4_0_packed16 data[]; } v_packed_q4_0;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q4_1 { block_q4_1_packed16 data[]; } k_packed_q4_1;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q4_1 { block_q4_1_packed16 data[]; } v_packed_q4_1;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q5_0 { block_q5_0_packed16 data[]; } k_packed_q5_0;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q5_0 { block_q5_0_packed16 data[]; } v_packed_q5_0;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q5_1 { block_q5_1_packed16 data[]; } k_packed_q5_1;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; } v_packed_q5_1;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0;
|
||||
|
||||
// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16
|
||||
// views, used by the MMQ K-side hot path for fast 4-uint loads.
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 data[]; } k_packed_q5_1_p32;
|
||||
|
||||
// Per-quant decode bodies are expanded once for the K view set and once for
|
||||
// the V view set. The macros take the buffer name as a parameter.
|
||||
#define FA_DEQUANT4_F32(BUF) \
|
||||
return FLOAT_TYPEV4(BUF.data[a_offset + ib]);
|
||||
|
||||
#define FA_DEQUANT4_Q4_0(BUF) { \
|
||||
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
|
||||
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
|
||||
uint shift = (iqs & 0x10) >> 2; \
|
||||
vui_lo >>= shift; \
|
||||
vui_hi >>= shift; \
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
|
||||
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); \
|
||||
}
|
||||
|
||||
#define FA_DEQUANT4_Q4_1(BUF) { \
|
||||
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
|
||||
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
|
||||
uint shift = (iqs & 0x10) >> 2; \
|
||||
vui_lo >>= shift; \
|
||||
vui_hi >>= shift; \
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
|
||||
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * nibbles \
|
||||
+ FLOAT_TYPE(BUF.data[a_offset + ib].m); \
|
||||
}
|
||||
|
||||
#define FA_DEQUANT4_Q5_0(BUF) { \
|
||||
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
|
||||
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
|
||||
uint shift = (iqs & 0x10) >> 2; \
|
||||
vui_lo >>= shift; \
|
||||
vui_hi >>= shift; \
|
||||
uint qh = uint(BUF.data[a_offset + ib].qh[0]) \
|
||||
| (uint(BUF.data[a_offset + ib].qh[1]) << 16); \
|
||||
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \
|
||||
(qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \
|
||||
* FLOAT_TYPE(16.0f); \
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
|
||||
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); \
|
||||
}
|
||||
|
||||
#define FA_DEQUANT4_Q5_1(BUF) { \
|
||||
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
|
||||
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
|
||||
uint shift = (iqs & 0x10) >> 2; \
|
||||
vui_lo >>= shift; \
|
||||
vui_hi >>= shift; \
|
||||
uint qh = BUF.data[a_offset + ib].qh; \
|
||||
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \
|
||||
(qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \
|
||||
* FLOAT_TYPE(16.0f); \
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
|
||||
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb) \
|
||||
+ FLOAT_TYPE(BUF.data[a_offset + ib].m); \
|
||||
}
|
||||
|
||||
#define FA_DEQUANT4_Q8_0(BUF) { \
|
||||
const i8vec2 v0 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 ])).xy; \
|
||||
const i8vec2 v1 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 + 1])).xy; \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \
|
||||
}
|
||||
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_F32: FA_DEQUANT4_F32 (k_packed_f32)
|
||||
case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(k_packed_q4_0)
|
||||
case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(k_packed_q4_1)
|
||||
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0)
|
||||
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1)
|
||||
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0)
|
||||
}
|
||||
} else {
|
||||
switch (FaTypeV) {
|
||||
case FA_TYPE_F32: FA_DEQUANT4_F32 (v_packed_f32)
|
||||
case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(v_packed_q4_0)
|
||||
case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(v_packed_q4_1)
|
||||
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0)
|
||||
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1)
|
||||
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0)
|
||||
}
|
||||
}
|
||||
return FLOAT_TYPEV4(0);
|
||||
}
|
||||
|
|
@ -1,149 +1,203 @@
|
|||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
// MMQ K-side helpers, asymmetric form. Each function dispatches on FaTypeK and
|
||||
// reads from the matching aliased K binding declared in flash_attn_dequant.glsl.
|
||||
// Spec-constant specialization folds the unused paths.
|
||||
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: {
|
||||
uint vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
}
|
||||
case FA_TYPE_Q4_1: { // uses packed32 alias
|
||||
uint vui = k_packed_q4_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
}
|
||||
case FA_TYPE_Q5_0: {
|
||||
uint vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0],
|
||||
k_packed_q5_0.data[a_offset + ib].qh[1]));
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
case FA_TYPE_Q5_1: { // qs via packed32, qh via packed16
|
||||
uint vui = k_packed_q5_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint qh = k_packed_q5_1.data[a_offset + ib].qh;
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
case FA_TYPE_Q8_0: {
|
||||
return pack32(i16vec2(k_packed_q8_0.data[a_offset + ib].qs[iqs / 2],
|
||||
k_packed_q8_0.data[a_offset + ib].qs[iqs / 2 + 1]));
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q5_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qh[1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
// Per-block scale/min, packed as (d, m). Single-scale types (Q4_0, Q5_0, Q8_0)
|
||||
// return (d, 0) so call sites always see the same shape.
|
||||
FLOAT_TYPEV2 get_k_scale(uint ib, uint a_offset) {
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + ib].d), 0.0);
|
||||
case FA_TYPE_Q4_1: return FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + ib].dm);
|
||||
case FA_TYPE_Q5_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + ib].d), 0.0);
|
||||
case FA_TYPE_Q5_1: return FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + ib].dm);
|
||||
case FA_TYPE_Q8_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + ib].d), 0.0);
|
||||
default: return FLOAT_TYPEV2(0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
|
||||
return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
|
||||
kvalues_iq4nl_const[idx.y],
|
||||
kvalues_iq4nl_const[idx.z],
|
||||
kvalues_iq4nl_const[idx.w]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
|
||||
}
|
||||
#else
|
||||
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
|
||||
// kblocksh[].qs is int32_t for the unified MMQ struct; uint sources need
|
||||
// explicit casts. The bit pattern is what we care about here -- the actual
|
||||
// signed/unsigned interpretation happens downstream in the dot product.
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: {
|
||||
kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2 + 1])));
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q4_1: {
|
||||
kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q4_1_p32.data[a_offset + global_ib].qs[iqs]);
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q5_0: {
|
||||
kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2 + 1])));
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qh[0],
|
||||
k_packed_q5_0.data[a_offset + global_ib].qh[1]));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q5_1: {
|
||||
kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q5_1_p32.data[a_offset + global_ib].qs[iqs]);
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = k_packed_q5_1.data[a_offset + global_ib].qh;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q8_0: {
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
break;
|
||||
}
|
||||
}
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
|
||||
}
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
||||
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
||||
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
|
||||
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
|
||||
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
|
||||
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
|
||||
#endif
|
||||
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
|
||||
#endif
|
||||
// Q4_0/Q5_0/Q8_0 store dm.x = d; Q4_1/Q5_1 store dm = (d, m) pair.
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + global_ib].d), 0.0); break;
|
||||
case FA_TYPE_Q4_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + global_ib].dm); break;
|
||||
case FA_TYPE_Q5_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + global_ib].d), 0.0); break;
|
||||
case FA_TYPE_Q5_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + global_ib].dm); break;
|
||||
case FA_TYPE_Q8_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + global_ib].d), 0.0); break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// d_per_step==8 hot path: read one full 32-element block worth of nibble-packed
|
||||
// int32 quants. Equivalent to 8 calls to get_k_qs(ib, d*4, a_offset) but reads
|
||||
// qh (Q5_*) and runs pack32 (Q4_0/Q5_0) once per block instead of per nibble
|
||||
// quad. iqs is always 0 in this path (hsk4 % 8 == 0 implies block-aligned).
|
||||
// Q8_0 takes the generic get_k_qs path because its qs layout (i8 pairs) doesn't
|
||||
// share this nibble shape.
|
||||
//
|
||||
// Returned via a struct so the caller's k_quants array (sized from spec
|
||||
// constants) doesn't need to match a fixed[8] out-parameter type.
|
||||
struct fa_k_qs_block8 {
|
||||
int32_t qs[8];
|
||||
};
|
||||
|
||||
fa_k_qs_block8 get_k_qs_block8(uint ib, uint a_offset) {
|
||||
fa_k_qs_block8 r;
|
||||
uint qh = 0;
|
||||
if (FaTypeK == FA_TYPE_Q5_0) {
|
||||
qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0],
|
||||
k_packed_q5_0.data[a_offset + ib].qh[1]));
|
||||
} else if (FaTypeK == FA_TYPE_Q5_1) {
|
||||
qh = k_packed_q5_1.data[a_offset + ib].qh;
|
||||
}
|
||||
const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1);
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
uint vui = 0;
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: { // packed16
|
||||
vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 0],
|
||||
k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 1]));
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q4_1: { // packed32 alias
|
||||
vui = k_packed_q4_1_p32.data[a_offset + ib].qs[d];
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q5_0: { // packed16
|
||||
vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 0],
|
||||
k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 1]));
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q5_1: { // packed32 alias
|
||||
vui = k_packed_q5_1_p32.data[a_offset + ib].qs[d];
|
||||
break;
|
||||
}
|
||||
}
|
||||
r.qs[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
r.qs[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
if (has_qh) {
|
||||
uint qh_lo = (qh >> (d * 4)) & 0xFu;
|
||||
uint qh_hi = (qh >> (d * 4 + 16)) & 0xFu;
|
||||
r.qs[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
r.qs[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
|
||||
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
return kblocksh[buf_ib].qs[pos];
|
||||
#endif
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0:
|
||||
case FA_TYPE_Q4_1: {
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4u : 0u;
|
||||
return int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu);
|
||||
}
|
||||
case FA_TYPE_Q5_0:
|
||||
case FA_TYPE_Q5_1: {
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4u : 0u;
|
||||
int32_t result = int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu);
|
||||
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4u)) & 0xFu;
|
||||
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
case FA_TYPE_Q8_0: {
|
||||
return kblocksh[buf_ib].qs[pos];
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
|
||||
#else
|
||||
return ACC_TYPE(0.0);
|
||||
#endif
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
case FA_TYPE_Q5_0: return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
case FA_TYPE_Q4_1:
|
||||
case FA_TYPE_Q5_1: return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
|
||||
default: return ACC_TYPE(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
|
||||
kblocksh[buf_ib].qs[iqs] = 0;
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
kblocksh[buf_ib].qs[iqs + 4] = 0;
|
||||
#endif
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,13 @@
|
|||
#if defined(DATA_A_Q4_0)
|
||||
#if defined(FA_MMQ_MIXED)
|
||||
// Mixed-K flash attention MMQ: superset cache that fits Q4_0/Q4_1/Q5_0/Q5_1/Q8_0.
|
||||
// Q4_*/Q5_* only use qs[0..3] and (for Q5_*) qh. Q8_0 uses qs[0..7]. Single-scale
|
||||
// types (Q4_0/Q5_0/Q8_0) leave dm.y unused.
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
uint32_t qh;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
|
|
|
|||
|
|
@ -643,42 +643,22 @@ void process_shaders() {
|
|||
|
||||
if (fp16) {
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp",
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
if (fp16) {
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
|
||||
}
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (tname != "f32") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue