mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 09:34:37 +00:00
vulkan: Handle updated FA dim2/3 definition (#14518)
* vulkan: Handle updated FA dim2/3 definition

  Pack mask boolean and n_head_log2 into a single dword to keep the push constant block under the 128B limit.

* handle null mask for gqa

* allow gqa with dim3>1
This commit is contained in:
parent
ddef99522d
commit
a0374a67e2
5 changed files with 26 additions and 24 deletions
|
@ -636,6 +636,7 @@ struct vk_flash_attn_push_constants {
|
|||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
uint32_t nem2;
|
||||
uint32_t nem3;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
|
@ -651,8 +652,7 @@ struct vk_flash_attn_push_constants {
|
|||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
uint32_t mask_n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
|
@ -6111,6 +6111,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
const uint32_t nem1 = mask ? mask->ne[1] : 0;
|
||||
const uint32_t nem2 = mask ? mask->ne[2] : 0;
|
||||
const uint32_t nem3 = mask ? mask->ne[3] : 0;
|
||||
|
||||
const uint32_t HSK = nek0;
|
||||
const uint32_t HSV = nev0;
|
||||
|
@ -6178,7 +6179,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
}
|
||||
|
||||
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
|
||||
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
||||
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
||||
// and change addressing calculations to index Q's dimension 2.
|
||||
|
@ -6348,17 +6349,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
}
|
||||
}
|
||||
|
||||
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
|
||||
|
||||
const vk_flash_attn_push_constants pc = { N, KV,
|
||||
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
|
||||
(uint32_t)neq2, (uint32_t)neq3,
|
||||
(uint32_t)nek2, (uint32_t)nek3,
|
||||
(uint32_t)nev2, (uint32_t)nev3,
|
||||
nem1, nem2,
|
||||
nem1, nem2, nem3,
|
||||
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
|
||||
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
|
||||
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
||||
scale, max_bias, logit_softcap,
|
||||
mask != nullptr, n_head_log2, m0, m1,
|
||||
mask_n_head_log2, m0, m1,
|
||||
gqa_ratio, split_kv, split_k };
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
|
@ -10303,12 +10306,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
// TODO: support broadcast
|
||||
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
|
||||
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
|
||||
if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
|
||||
return false;
|
||||
}
|
||||
// It's straightforward to support different K/V dequant, but would
|
||||
// significantly increase the number of pipelines
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue