Mirror of https://github.com/LostRuins/koboldcpp.git
vulkan: support softmax/FA batch and broadcast (#14449)
commit 8875523eb3 (parent ec68e84c32)
7 changed files with 80 additions and 44 deletions

Only the hunks from ggml/src/ggml-vulkan/ggml-vulkan.cpp are reproduced below.
@@ -633,6 +633,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;

     uint32_t nb01;
     uint32_t nb02;
@@ -643,7 +644,6 @@ struct vk_flash_attn_push_constants {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;

     float scale;
     float max_bias;
@@ -658,6 +658,7 @@ struct vk_flash_attn_push_constants {
     uint32_t split_kv;
     uint32_t k_num;
 };
+static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");

 struct vk_op_push_constants {
     uint32_t KX;
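Context for the new static_assert: the Vulkan spec only guarantees a maxPushConstantsSize of 128 bytes, so a push-constant block larger than that can fail pipeline-layout creation on conforming devices. Dropping nb31 while adding nem2 keeps the struct within the limit, and the assert turns any future overflow into a compile-time error. A minimal sketch of the same guard, with an illustrative field list rather than the real one:

    #include <cstdint>

    // Illustrative push-constant block: 24 uints + 8 floats = 128 bytes,
    // exactly the minimum size every Vulkan implementation must support.
    struct pc_example {
        uint32_t dims[24];
        float    scalars[8];
    };
    static_assert(sizeof(pc_example) <= 128,
                  "push constants must fit the 128-byte Vulkan minimum");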
@@ -756,6 +757,14 @@ struct vk_op_rope_push_constants {
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t ne12;
+    uint32_t ne13;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
     float scale;
     float max_bias;
     float m0;
@@ -6040,7 +6049,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

     const uint32_t nem1 = mask ? mask->ne[1] : 0;
-    const uint32_t nbm1 = mask ? mask->nb[1] : 0;
+    const uint32_t nem2 = mask ? mask->ne[2] : 0;

     const uint32_t D = neq0;
     uint32_t N = neq1;
@@ -6203,7 +6212,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Try to use split_k when KV is large enough to be worth the overhead
     if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
-        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
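The extra workgroups_z factor matters because the batch dimension (ne3) now contributes workgroups of its own: once batches already fill the GPU, fewer KV splits are needed. A worked example with illustrative numbers, not taken from the source:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical device and workload.
        uint32_t shader_core_count = 32; // SMs on the device
        uint32_t workgroups_y      = 8;  // heads
        uint32_t workgroups_z      = 2;  // batches (ne3), new in this commit

        uint32_t old_split_k = shader_core_count * 2 / workgroups_y;                  // 8
        uint32_t new_split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); // 4
        printf("split_k: old=%u new=%u\n", old_split_k, new_split_k);
        // 8 heads x 2 batches x 4 splits = 64 workgroups = 2 per SM, as intended.
        return 0;
    }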
@@ -6213,9 +6222,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }

-    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
-    // and the per-row m and L values (ne1 rows).
-    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
+    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
     if (split_k_size > ctx->device->max_memory_allocation_size) {
         GGML_ABORT("Requested preallocation size is too large");
     }
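To make the new comment concrete: each (split, batch) pair stores a D x ne1 O matrix, and the m/L rows for all pairs follow after every matrix. A hedged sketch of offset helpers consistent with that comment; the batch-major ordering inside each region is an assumption, and a worked size appears in the trailing comment:

    #include <cstddef>
    #include <cstdint>

    // O matrix for batch i3, split k (in floats from the scratch base).
    size_t o_offset_floats(uint32_t i3, uint32_t k, uint32_t D, uint32_t ne1,
                           uint32_t split_k) {
        return ((size_t)i3 * split_k + k) * D * ne1; // batch-major: an assumption
    }
    // The m and L rows start after all ne3 * split_k matrices.
    size_t rows_base_floats(uint32_t D, uint32_t ne1, uint32_t split_k, uint32_t ne3) {
        return (size_t)ne3 * split_k * D * ne1;
    }
    // e.g. D=128, ne1=1, split_k=4, ne3=2:
    // split_k_size = (128*1*4 + 1*4*2) * 4 * 2 = 4160 bytes.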
@@ -6307,11 +6316,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         (uint32_t)neq2, (uint32_t)neq3,
         (uint32_t)nek2, (uint32_t)nek3,
         (uint32_t)nev2, (uint32_t)nev3,
-        nem1,
+        nem1, nem2,
         q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
         k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
         v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
-        nbm1,
         scale, max_bias, logit_softcap,
         mask != nullptr, n_head_log2, m0, m1,
         gqa_ratio, split_kv, split_k };
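Why nem2 is pushed while the mask stride nbm1 is dropped: with the mask treated as densely packed, its row stride is derivable in the shader, and what the shader needs for broadcast is the mask's extent along dim 2 so it can wrap batch indices the way ggml broadcasting always does (source index = destination index mod source extent). A hedged sketch of that convention; the helper and its packing assumptions are hypothetical, not the shader's actual code:

    #include <cstddef>
    #include <cstdint>

    // With nem2 == 1 every batch reads mask slice 0; with nem2 == ne3 each
    // batch reads its own slice. Assumes a dense [KV, nem1, nem2] mask.
    size_t mask_elem_offset(uint32_t i2, uint32_t row, uint32_t col,
                            uint32_t KV, uint32_t nem1, uint32_t nem2) {
        return (((size_t)(i2 % nem2) * nem1) + row) * KV + col;
    }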
@@ -6334,13 +6342,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });

             ggml_vk_sync_buffers(subctx);
-            const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+            const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
             ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
                 {
                     vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                     vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                 },
-                pc2, { (uint32_t)ne1, 1, 1 });
+                pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
         } else {
             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
                 {
@@ -7666,7 +7674,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
     const uint32_t nrows_y = (uint32_t)src0->ne[1];

-    const uint32_t n_head_kv = nrows_x/nrows_y;
+    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
+    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
+    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
+    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
+    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
+
+    const uint32_t n_head_kv = src0->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

     const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
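Two things worth unpacking here. First, the old n_head_kv = nrows_x/nrows_y equals ne[2] * ne[3], which over-counts heads as soon as batching (ne[3] > 1) is allowed; reading src0->ne[2] directly stays correct. Second, ggml strides (nb) are in bytes with nb[0] the element size, so the nb[i] / nb[0] divisions hand the shader element strides that work for both F16 and F32 masks. A small sketch of that relationship for a hypothetical contiguous F32 tensor:

    #include <cassert>
    #include <cstddef>

    int main() {
        // Hypothetical contiguous F32 tensor of shape [64, 16, 4, 2].
        size_t ne[4] = {64, 16, 4, 2};
        size_t nb[4];
        nb[0] = sizeof(float);             // byte size of one element
        for (int i = 1; i < 4; i++) {
            nb[i] = nb[i - 1] * ne[i - 1]; // byte stride of dimension i
        }
        // Dividing by nb[0] recovers element strides: 64, 1024, 4096.
        assert(nb[1] / nb[0] == 64);
        assert(nb[2] / nb[0] == 1024);
        assert(nb[3] / nb[0] == 4096);
        return 0;
    }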
@@ -7675,6 +7689,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
+        ne12, ne13,
+        nb11, nb12, nb13,
         scale, max_bias,
         m0, m1,
         n_head_log2,
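For reference, the m0/m1 context lines above follow ggml's usual ALiBi scheme: m0 = 2^(-max_bias / n_head_log2), m1 = 2^(-max_bias / (2 * n_head_log2)), and head h gets slope m0^(h+1) when h < n_head_log2, otherwise m1^(2*(h - n_head_log2) + 1). A worked example with illustrative values:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Illustrative: 12 KV heads, max_bias = 8.
        const float    max_bias    = 8.0f;
        const uint32_t n_head_kv   = 12;
        const uint32_t n_head_log2 = 1u << (uint32_t)floorf(log2f((float)n_head_kv)); // 8
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2); // 2^-1   = 0.5
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // 2^-0.5 ~ 0.7071
        printf("slope(h=0)=%f slope(h=9)=%f\n",
               powf(m0, 0 + 1), powf(m1, 2 * (9 - n_head_log2) + 1)); // 0.5, ~0.3536
        return 0;
    }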
@@ -10248,11 +10265,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                 return false;
             }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             // It's straightforward to support different K/V dequant, but would
             // significantly increase the number of pipelines
             if (op->src[1]->type != op->src[2]->type) {
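With the ne[3] != 1 rejection removed, the backend now accepts flash-attention nodes whose Q/K/V carry a batch dimension. A hedged graph-building sketch of what is newly admitted; the shapes and the context name ctx0 are illustrative, not from the source:

    #include <math.h>
    #include "ggml.h"

    void build_batched_fa(struct ggml_context * ctx0) {
        // Q: [D=128, N=64, heads=8, batch=2]; K/V: [128, KV=256, 8, 2].
        struct ggml_tensor * q    = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 128,  64, 8, 2);
        struct ggml_tensor * k    = ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 128, 256, 8, 2);
        struct ggml_tensor * v    = ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 128, 256, 8, 2);
        // F16 mask shared by all heads and batches.
        struct ggml_tensor * mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 256,  64, 1, 1);
        struct ggml_tensor * out  = ggml_flash_attn_ext(ctx0, q, k, v, mask,
                                                        1.0f / sqrtf(128.0f), 0.0f, 0.0f);
        (void)out; // would be added to a graph and run on the Vulkan backend
    }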
@@ -10413,13 +10425,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
         case GGML_OP_SUM:
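Collapsing GGML_OP_SOFT_MAX into the fall-through group that returns true is the user-visible half of the change: batched inputs (src0->ne[3] > 1) and masks broadcast along dims 2 and 3 are now accepted instead of falling back to another backend. A hedged sketch of a newly admitted node, again with illustrative shapes and context name:

    #include "ggml.h"

    void build_broadcast_soft_max(struct ggml_context * ctx0) {
        // Batched src0 with a mask that broadcasts across heads and batches
        // (src1->ne[2] == src1->ne[3] == 1), previously rejected here.
        struct ggml_tensor * a    = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 256, 64, 8, 2);
        struct ggml_tensor * mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 256, 64, 1, 1);
        struct ggml_tensor * out  = ggml_soft_max_ext(ctx0, a, mask, 0.125f, 8.0f);
        (void)out;
    }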