mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-17 04:09:19 +00:00
ggml-webgpu: address precision issues for multimodal (#22808)
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp
This commit is contained in:
parent
89730c8d26
commit
239a497e5f
6 changed files with 295 additions and 186 deletions
|
|
@ -91,6 +91,7 @@ struct ggml_webgpu_shader_lib_context {
|
|||
uint32_t sg_mat_m = 0;
|
||||
uint32_t sg_mat_n = 0;
|
||||
uint32_t sg_mat_k = 0;
|
||||
uint32_t min_subgroup_size = 0;
|
||||
uint32_t max_subgroup_size = 0;
|
||||
};
|
||||
|
||||
|
|
@ -531,7 +532,9 @@ enum ggml_webgpu_flash_attn_path : uint32_t {
|
|||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_pipeline_key {
|
||||
ggml_type q_type;
|
||||
ggml_type kv_type;
|
||||
ggml_type dst_type;
|
||||
uint32_t head_dim_qk;
|
||||
uint32_t head_dim_v;
|
||||
bool kv_direct;
|
||||
|
|
@ -542,16 +545,19 @@ struct ggml_webgpu_flash_attn_pipeline_key {
|
|||
uint32_t path;
|
||||
|
||||
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
||||
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
||||
kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
|
||||
has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path;
|
||||
return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
|
||||
head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
|
||||
kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks &&
|
||||
uses_logit_softcap == other.uses_logit_softcap && path == other.path;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.q_type);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_type);
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
||||
|
|
@ -595,14 +601,14 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_
|
|||
}
|
||||
|
||||
inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
|
||||
const ggml_webgpu_shader_lib_context & context,
|
||||
uint32_t path) {
|
||||
const ggml_webgpu_shader_lib_context & context,
|
||||
const ggml_webgpu_flash_attn_decisions & decisions) {
|
||||
const bool has_mask = context.src3 != nullptr;
|
||||
const bool has_sinks = context.src4 != nullptr;
|
||||
bool kv_direct = false;
|
||||
if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH;
|
||||
if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
|
||||
kv_direct_align = context.sg_mat_k;
|
||||
}
|
||||
kv_direct = (context.src1->type == GGML_TYPE_F16) &&
|
||||
|
|
@ -611,7 +617,9 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
|
|||
}
|
||||
|
||||
ggml_webgpu_flash_attn_pipeline_key key = {};
|
||||
key.q_type = context.src0->type;
|
||||
key.kv_type = context.src1->type;
|
||||
key.dst_type = context.dst->type;
|
||||
key.head_dim_qk = (uint32_t) context.src0->ne[0];
|
||||
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
||||
key.kv_direct = kv_direct;
|
||||
|
|
@ -619,13 +627,14 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
|
|||
key.has_mask = has_mask;
|
||||
key.has_sinks = has_sinks;
|
||||
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
|
||||
key.path = path;
|
||||
key.path = decisions.path;
|
||||
return key;
|
||||
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
|
||||
uint32_t head_dim_v;
|
||||
uint32_t wg_size;
|
||||
uint32_t head_dim_v;
|
||||
uint32_t wg_size;
|
||||
ggml_type dst_type;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
|
||||
|
|
@ -633,13 +642,14 @@ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
|
|||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
||||
ggml_webgpu_hash_combine(seed, key.wg_size);
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
|
||||
const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
|
||||
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
|
||||
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type;
|
||||
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_blk_pipeline_key {
|
||||
|
|
@ -662,19 +672,32 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
|||
uint32_t head_dim_qk,
|
||||
uint32_t head_dim_v,
|
||||
bool has_mask,
|
||||
bool kv_direct) {
|
||||
bool kv_direct,
|
||||
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
|
||||
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
|
||||
size_t f16_elems = 0;
|
||||
size_t f32_elems = 0;
|
||||
f16_elems += q_tile * head_dim_qk; // q_shmem
|
||||
if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
f32_elems += head_dim_qk; // q_shmem
|
||||
if (!kv_direct) {
|
||||
f32_elems += kv_tile * max_head_dim; // kv_shmem
|
||||
}
|
||||
f32_elems += head_dim_v; // o_shmem
|
||||
if (has_mask) {
|
||||
f32_elems += kv_tile; // mask_shmem
|
||||
}
|
||||
f32_elems += kv_tile; // inter_shmem
|
||||
return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
}
|
||||
f32_elems += q_tile * head_dim_qk; // q_shmem
|
||||
if (!kv_direct) {
|
||||
f16_elems += kv_tile * max_head_dim; // kv_shmem
|
||||
f32_elems += kv_tile * max_head_dim; // kv_shmem
|
||||
}
|
||||
f16_elems += q_tile * head_dim_v; // o_shmem
|
||||
f32_elems += q_tile * head_dim_v; // o_shmem
|
||||
if (has_mask) {
|
||||
f16_elems += q_tile * kv_tile; // mask_shmem
|
||||
f32_elems += q_tile * kv_tile; // mask_shmem
|
||||
}
|
||||
f16_elems += q_tile * kv_tile; // inter_shmem
|
||||
f32_elems += q_tile * kv_tile; // inter_shmem
|
||||
f32_elems += q_tile; // row_max_shmem
|
||||
f32_elems += q_tile; // exp_sum_shmem
|
||||
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
|
|
@ -684,27 +707,27 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
|
|||
const ggml_webgpu_flash_attn_pipeline_key & key) {
|
||||
const size_t limit_bytes = context.wg_mem_limit_bytes;
|
||||
uint32_t q_tile = context.sg_mat_m;
|
||||
uint32_t kv_granularity = context.sg_mat_n;
|
||||
uint32_t kv_granularity = std::max(1u, context.sg_mat_n);
|
||||
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
|
||||
kv_granularity = std::max(1u, context.max_subgroup_size);
|
||||
kv_granularity = 1u;
|
||||
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
q_tile = 1u;
|
||||
kv_granularity = 8u;
|
||||
}
|
||||
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
size_t bytes_per_kv = 0;
|
||||
if (!key.kv_direct) {
|
||||
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
|
||||
const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v,
|
||||
key.has_mask, key.kv_direct, key.path);
|
||||
if (limit_bytes <= base_q_bytes) {
|
||||
return 0;
|
||||
}
|
||||
if (key.has_mask) {
|
||||
bytes_per_kv += q_tile;
|
||||
const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v,
|
||||
key.has_mask, key.kv_direct, key.path);
|
||||
const size_t bytes_per_kv = one_kv_bytes - base_q_bytes;
|
||||
if (bytes_per_kv == 0) {
|
||||
return 0;
|
||||
}
|
||||
bytes_per_kv += q_tile;
|
||||
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
||||
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
||||
return (max_kv_tile / kv_granularity) * kv_granularity;
|
||||
const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
||||
return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity);
|
||||
}
|
||||
|
||||
inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
||||
|
|
@ -731,14 +754,18 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
|
||||
const bool kv_vec_type_supported =
|
||||
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
|
||||
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
|
||||
(context.src2->type == K->type);
|
||||
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
|
||||
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
|
||||
(context.src2->type == K->type);
|
||||
const bool tile_can_dispatch_all_q_rows =
|
||||
context.max_subgroup_size > 0 &&
|
||||
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
|
||||
const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
|
||||
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
|
||||
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec;
|
||||
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
tile_can_dispatch_all_q_rows && !use_vec;
|
||||
|
||||
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
|
||||
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
|
||||
|
|
@ -749,7 +776,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||
return decisions;
|
||||
}
|
||||
|
||||
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
|
||||
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions);
|
||||
decisions.kv_direct = key.kv_direct;
|
||||
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
|
||||
// invalidate if even the smallest kv_tile doesn't fit in shared memory
|
||||
|
|
@ -778,21 +805,20 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||
std::min(64u, max_kv_tile) :
|
||||
std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
||||
decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
||||
GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE :
|
||||
std::min(std::max(1u, context.max_wg_size),
|
||||
std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
|
||||
GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) :
|
||||
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
||||
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size);
|
||||
decisions.kv_tile =
|
||||
std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity);
|
||||
if (decisions.kv_tile == 0) {
|
||||
return decisions;
|
||||
}
|
||||
|
||||
if (decisions.kv_direct) {
|
||||
GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
||||
decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
||||
std::max(1u, context.max_subgroup_size) :
|
||||
context.sg_mat_n;
|
||||
decisions.kv_tile -=
|
||||
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n;
|
||||
}
|
||||
}
|
||||
return decisions;
|
||||
|
|
@ -1577,7 +1603,7 @@ class ggml_webgpu_shader_lib {
|
|||
key.type = context.dst->type;
|
||||
key.d_state = (int) context.src0->ne[0];
|
||||
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
|
||||
auto it = ssm_scan_pipelines.find(key);
|
||||
if (it != ssm_scan_pipelines.end()) {
|
||||
|
|
@ -1694,10 +1720,10 @@ class ggml_webgpu_shader_lib {
|
|||
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
1 :
|
||||
0;
|
||||
|
||||
auto it = mul_mat_vec_pipelines.find(key);
|
||||
if (it != mul_mat_vec_pipelines.end()) {
|
||||
|
|
@ -1805,13 +1831,13 @@ class ggml_webgpu_shader_lib {
|
|||
|
||||
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_mul_mat_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_subgroup_matrix = context.supports_subgroup_matrix;
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_subgroup_matrix = context.supports_subgroup_matrix;
|
||||
|
||||
auto it = mul_mat_fast_pipelines.find(key);
|
||||
if (it != mul_mat_fast_pipelines.end()) {
|
||||
|
|
@ -2074,10 +2100,10 @@ class ggml_webgpu_shader_lib {
|
|||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.n_experts = context.src0->ne[2];
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
1 :
|
||||
0;
|
||||
|
||||
auto it = mul_mat_id_pipelines.find(key);
|
||||
if (it != mul_mat_id_pipelines.end()) {
|
||||
|
|
@ -2194,10 +2220,10 @@ class ggml_webgpu_shader_lib {
|
|||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.n_experts = context.src0->ne[2];
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
1 :
|
||||
0;
|
||||
|
||||
auto it = mul_mat_id_vec_pipelines.find(key);
|
||||
if (it != mul_mat_id_vec_pipelines.end()) {
|
||||
|
|
@ -2558,7 +2584,7 @@ class ggml_webgpu_shader_lib {
|
|||
const ggml_webgpu_flash_attn_decisions decisions =
|
||||
ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment);
|
||||
GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE);
|
||||
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
|
||||
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions);
|
||||
auto it = flash_attn_pipelines.find(key);
|
||||
if (it != flash_attn_pipelines.end()) {
|
||||
return it->second;
|
||||
|
|
@ -2586,6 +2612,30 @@ class ggml_webgpu_shader_lib {
|
|||
}
|
||||
variant += std::string("_") + ggml_type_name(key.kv_type);
|
||||
|
||||
switch (key.q_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("Q_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("Q_F16");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported Q type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_q") + ggml_type_name(key.q_type);
|
||||
|
||||
switch (key.dst_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("DST_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("DST_F16");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported dst type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_dst") + ggml_type_name(key.dst_type);
|
||||
|
||||
if (key.has_mask) {
|
||||
defines.push_back("MASK");
|
||||
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
|
|
@ -2625,9 +2675,11 @@ class ggml_webgpu_shader_lib {
|
|||
shader_src = wgsl_flash_attn_vec_split;
|
||||
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
shader_src = wgsl_flash_attn_tile;
|
||||
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size));
|
||||
defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u");
|
||||
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
|
||||
defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v)));
|
||||
variant += "_tile";
|
||||
variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" +
|
||||
std::to_string(context.max_subgroup_size);
|
||||
} else {
|
||||
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
|
||||
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
|
||||
|
|
@ -2677,6 +2729,7 @@ class ggml_webgpu_shader_lib {
|
|||
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {};
|
||||
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
||||
key.dst_type = context.dst->type;
|
||||
key.wg_size = context.max_wg_size;
|
||||
auto it = flash_attn_vec_reduce_pipelines.find(key);
|
||||
if (it != flash_attn_vec_reduce_pipelines.end()) {
|
||||
|
|
@ -2686,6 +2739,18 @@ class ggml_webgpu_shader_lib {
|
|||
std::vector<std::string> defines;
|
||||
std::string variant = "flash_attn_vec_reduce";
|
||||
|
||||
switch (key.dst_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("DST_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("DST_F16");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported dst type for flash attention vec reduce shader");
|
||||
}
|
||||
variant += std::string("_dst") + ggml_type_name(key.dst_type);
|
||||
|
||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
||||
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
||||
|
||||
|
|
|
|||
|
|
@ -187,6 +187,7 @@ struct webgpu_capabilities {
|
|||
uint32_t sg_mat_k = 0;
|
||||
|
||||
uint32_t subgroup_size = 0;
|
||||
uint32_t min_subgroup_size = 0;
|
||||
uint32_t max_subgroup_size = 0;
|
||||
size_t memset_bytes_per_thread;
|
||||
};
|
||||
|
|
@ -1442,6 +1443,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
|
||||
// Get or create pipeline
|
||||
|
|
@ -1750,6 +1752,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
|
||||
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
|
|
@ -3469,6 +3472,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
|||
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
|
||||
|
||||
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
|
||||
|
|
@ -3667,8 +3671,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||
#endif
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
|
||||
|
||||
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
||||
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
||||
// Runtime subgroup size can be any supported size in this range. Shaders
|
||||
// that allocate per-lane register arrays must size them for the minimum.
|
||||
ctx->webgpu_global_ctx->capabilities.min_subgroup_size = info.subgroupMinSize;
|
||||
ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
|
||||
|
|
@ -4024,11 +4029,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
|
||||
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
|
||||
shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
|
||||
shader_lib_ctx.max_wg_size =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.wg_mem_limit_bytes =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
|
||||
|
||||
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
|
||||
|
|
@ -4040,9 +4048,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
break;
|
||||
}
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
const size_t min_bytes =
|
||||
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
|
||||
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
|
||||
decisions.kv_direct, decisions.path);
|
||||
if (min_bytes > limit_bytes) {
|
||||
supports_op = false;
|
||||
}
|
||||
|
|
@ -4050,9 +4058,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
}
|
||||
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
const size_t min_bytes =
|
||||
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
|
||||
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
|
||||
decisions.kv_direct, decisions.path);
|
||||
if (min_bytes > limit_bytes) {
|
||||
supports_op = false;
|
||||
}
|
||||
|
|
@ -4063,9 +4071,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
supports_op = false;
|
||||
break;
|
||||
}
|
||||
const size_t min_bytes =
|
||||
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
|
||||
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
|
||||
decisions.kv_direct, decisions.path);
|
||||
if (min_bytes > limit_bytes) {
|
||||
supports_op = false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,33 @@
|
|||
enable f16;
|
||||
enable subgroups;
|
||||
|
||||
#ifdef Q_F16
|
||||
#define Q_TYPE f16
|
||||
#else
|
||||
#define Q_TYPE f32
|
||||
#endif
|
||||
|
||||
#ifdef KV_F32
|
||||
#define KV_TYPE f32
|
||||
#else
|
||||
#define KV_TYPE f16
|
||||
#endif
|
||||
|
||||
#ifdef DST_F16
|
||||
#define DST_TYPE f16
|
||||
#else
|
||||
#define DST_TYPE f32
|
||||
#endif
|
||||
|
||||
#define HEAD_DIM_QK 64
|
||||
#define HEAD_DIM_V 64
|
||||
#define KV_STAGE_STRIDE 64
|
||||
#define Q_TILE 4
|
||||
#define KV_TILE 64
|
||||
#define WG_SIZE 128
|
||||
#ifndef MIN_SUBGROUP_SIZE
|
||||
#define MIN_SUBGROUP_SIZE MAX_SUBGROUP_SIZE
|
||||
#endif
|
||||
|
||||
struct Params {
|
||||
offset_q: u32,
|
||||
|
|
@ -41,13 +62,13 @@ struct Params {
|
|||
m1: f32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||
#define V K
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<f16>>;
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
|
||||
#endif
|
||||
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
|
|
@ -92,17 +113,17 @@ struct Params {
|
|||
#endif
|
||||
#endif
|
||||
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
|
||||
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
|
||||
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
|
||||
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
|
||||
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
|
||||
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
|
||||
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
|
||||
|
||||
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
|
||||
var<workgroup> kv_shmem: array<f16, KV_TILE * KV_STAGE_STRIDE>;
|
||||
var<workgroup> q_shmem: array<f32, Q_TILE * HEAD_DIM_QK>;
|
||||
var<workgroup> kv_shmem: array<f32, KV_TILE * KV_STAGE_STRIDE>;
|
||||
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
|
|
@ -158,10 +179,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let q_col = elem_idx % HEAD_DIM_QK;
|
||||
let head_q_row = q_row_start + q_tile_row;
|
||||
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
|
||||
q_shmem[elem_idx] = f16(select(
|
||||
q_shmem[elem_idx] = select(
|
||||
0.0,
|
||||
Q[global_q_row_offset + q_col] * params.scale,
|
||||
head_q_row < params.seq_len_q));
|
||||
f32(Q[global_q_row_offset + q_col]) * params.scale,
|
||||
head_q_row < params.seq_len_q);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
|
@ -192,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
|
||||
let k4 = K[k_vec_index];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = k4.x;
|
||||
kv_shmem[kv_off + 1u] = k4.y;
|
||||
kv_shmem[kv_off + 2u] = k4.z;
|
||||
kv_shmem[kv_off + 3u] = k4.w;
|
||||
kv_shmem[kv_off + 0u] = f32(k4.x);
|
||||
kv_shmem[kv_off + 1u] = f32(k4.y);
|
||||
kv_shmem[kv_off + 2u] = f32(k4.z);
|
||||
kv_shmem[kv_off + 3u] = f32(k4.w);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
|
@ -213,16 +234,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
|
||||
let q_off = q_base + chunk * 4u;
|
||||
let qv = vec4<f32>(
|
||||
f32(q_shmem[q_off + 0u]),
|
||||
f32(q_shmem[q_off + 1u]),
|
||||
f32(q_shmem[q_off + 2u]),
|
||||
f32(q_shmem[q_off + 3u]));
|
||||
q_shmem[q_off + 0u],
|
||||
q_shmem[q_off + 1u],
|
||||
q_shmem[q_off + 2u],
|
||||
q_shmem[q_off + 3u]);
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
let kv = vec4<f32>(
|
||||
f32(kv_shmem[kv_off + 0u]),
|
||||
f32(kv_shmem[kv_off + 1u]),
|
||||
f32(kv_shmem[kv_off + 2u]),
|
||||
f32(kv_shmem[kv_off + 3u]));
|
||||
kv_shmem[kv_off + 0u],
|
||||
kv_shmem[kv_off + 1u],
|
||||
kv_shmem[kv_off + 2u],
|
||||
kv_shmem[kv_off + 3u]);
|
||||
dot_val += dot(qv, kv);
|
||||
}
|
||||
#ifdef LOGIT_SOFTCAP
|
||||
|
|
@ -264,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
|
||||
let v4 = V[v_vec_index];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = v4.x;
|
||||
kv_shmem[kv_off + 1u] = v4.y;
|
||||
kv_shmem[kv_off + 2u] = v4.z;
|
||||
kv_shmem[kv_off + 3u] = v4.w;
|
||||
kv_shmem[kv_off + 0u] = f32(v4.x);
|
||||
kv_shmem[kv_off + 1u] = f32(v4.y);
|
||||
kv_shmem[kv_off + 2u] = f32(v4.z);
|
||||
kv_shmem[kv_off + 3u] = f32(v4.w);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
|
@ -288,10 +309,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let p = p_shmem[subgroup_p_offset + kv_local];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
let v4 = vec4<f32>(
|
||||
f32(kv_shmem[kv_off + 0u]),
|
||||
f32(kv_shmem[kv_off + 1u]),
|
||||
f32(kv_shmem[kv_off + 2u]),
|
||||
f32(kv_shmem[kv_off + 3u]));
|
||||
kv_shmem[kv_off + 0u],
|
||||
kv_shmem[kv_off + 1u],
|
||||
kv_shmem[kv_off + 2u],
|
||||
kv_shmem[kv_off + 3u]);
|
||||
acc += p * v4;
|
||||
}
|
||||
out_regs[reg_idx] = acc;
|
||||
|
|
@ -324,7 +345,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
continue;
|
||||
}
|
||||
let dst_vec_index = (row_base + chunk * 4u) >> 2u;
|
||||
dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
|
||||
dst[dst_vec_index] = vec4<DST_TYPE>(out_regs[reg_idx] * inv_exp_sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,12 @@ diagnostic(off, subgroup_uniformity);
|
|||
enable f16;
|
||||
enable subgroups;
|
||||
|
||||
#ifdef DST_F16
|
||||
#define DST_TYPE f16
|
||||
#else
|
||||
#define DST_TYPE f32
|
||||
#endif
|
||||
|
||||
// Default values
|
||||
#define HEAD_DIM_V 64
|
||||
#define WG_SIZE 128
|
||||
|
|
@ -17,7 +23,7 @@ struct Params {
|
|||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
|
||||
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
|
||||
@group(0) @binding(2) var<uniform> params: Params;
|
||||
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
|
|
@ -72,7 +78,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
|
||||
if (thread == 0u) {
|
||||
let dst_vec_index = (row_base + elem_base) >> 2u;
|
||||
dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
|
||||
dst[dst_vec_index] = vec4<DST_TYPE>(vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,18 @@ enable subgroups;
|
|||
#define KV_TYPE f16
|
||||
#endif
|
||||
|
||||
#ifdef Q_F16
|
||||
#define Q_TYPE f16
|
||||
#else
|
||||
#define Q_TYPE f32
|
||||
#endif
|
||||
|
||||
#ifdef DST_F16
|
||||
#define DST_TYPE f16
|
||||
#else
|
||||
#define DST_TYPE f32
|
||||
#endif
|
||||
|
||||
#define HEAD_DIM_QK 64
|
||||
#define HEAD_DIM_V 64
|
||||
|
||||
|
|
@ -89,7 +101,7 @@ struct Params {
|
|||
nwg: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
|
||||
#ifdef KV_OVERLAP
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
|
|
@ -191,41 +203,41 @@ struct Params {
|
|||
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
|
||||
#endif
|
||||
@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
|
||||
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||
|
||||
// Just a very small float value.
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
|
||||
var<workgroup> q_shmem: array<f16, HEAD_DIM_QK>;
|
||||
var<workgroup> q_shmem: array<f32, HEAD_DIM_QK>;
|
||||
|
||||
#ifndef KV_DIRECT
|
||||
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
||||
// we can reuse the same shmem for K and V since we only need one at a time
|
||||
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
|
||||
var<workgroup> kv_shmem: array<f32, kv_shmem_size>;
|
||||
#endif
|
||||
|
||||
var<workgroup> o_shmem: array<f16, HEAD_DIM_V>;
|
||||
var<workgroup> o_shmem: array<f32, HEAD_DIM_V>;
|
||||
|
||||
#ifdef MASK
|
||||
// storage for mask values
|
||||
var<workgroup> mask_shmem: array<f16, KV_TILE>;
|
||||
var<workgroup> mask_shmem: array<f32, KV_TILE>;
|
||||
#endif
|
||||
|
||||
// note that we reuse the same storage for both since we only need one at a time
|
||||
var<workgroup> inter_shmem: array<f16, KV_TILE>;
|
||||
var<workgroup> inter_shmem: array<f32, KV_TILE>;
|
||||
|
||||
// Storage for row max and exp sum during online softmax
|
||||
fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
|
||||
var v = select(FLOAT_MIN,
|
||||
f32(inter_shmem[kv_idx]) * params.scale,
|
||||
inter_shmem[kv_idx] * params.scale,
|
||||
kv_idx < KV_TILE);
|
||||
#ifdef LOGIT_SOFTCAP
|
||||
v = params.logit_softcap * tanh(v);
|
||||
#endif
|
||||
#ifdef MASK
|
||||
if (apply_mask) {
|
||||
var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE);
|
||||
var mask_val = select(0.0, mask_shmem[kv_idx], kv_idx < KV_TILE);
|
||||
v += select(mask_val, slope * mask_val, has_bias);
|
||||
}
|
||||
#endif
|
||||
|
|
@ -289,10 +301,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
// load the single Q row into shared memory
|
||||
for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1;
|
||||
q_shmem[elem_idx] = f16(select(
|
||||
q_shmem[elem_idx] = select(
|
||||
0.0,
|
||||
Q[global_q_row_offset + elem_idx],
|
||||
q_row_start < params.seq_len_q));
|
||||
f32(Q[global_q_row_offset + elem_idx]),
|
||||
q_row_start < params.seq_len_q);
|
||||
}
|
||||
|
||||
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
|
||||
|
|
@ -308,7 +320,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let blk_state = blk_state_local;
|
||||
let skip_tile = blk_state == 0u;
|
||||
for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) {
|
||||
inter_shmem[elem_idx] = f16(0.0);
|
||||
inter_shmem[elem_idx] = 0.0;
|
||||
}
|
||||
|
||||
// load k tile into shared memory
|
||||
|
|
@ -331,8 +343,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d);
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d);
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
|
|
@ -359,7 +371,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let q_val = f32(q_byte) * f32(d);
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
|
|
@ -377,10 +389,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
|
||||
let vec_idx = (global_k_row_offset + k_col) >> 2u;
|
||||
let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f16(k4.x);
|
||||
kv_shmem[elem_idx + 1u] = f16(k4.y);
|
||||
kv_shmem[elem_idx + 2u] = f16(k4.z);
|
||||
kv_shmem[elem_idx + 3u] = f16(k4.w);
|
||||
kv_shmem[elem_idx + 0u] = f32(k4.x);
|
||||
kv_shmem[elem_idx + 1u] = f32(k4.y);
|
||||
kv_shmem[elem_idx + 2u] = f32(k4.z);
|
||||
kv_shmem[elem_idx + 3u] = f32(k4.w);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -401,20 +413,20 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let q_off = i * 4u;
|
||||
|
||||
let qv = vec4<f32>(
|
||||
f32(q_shmem[q_off + 0u]),
|
||||
f32(q_shmem[q_off + 1u]),
|
||||
f32(q_shmem[q_off + 2u]),
|
||||
f32(q_shmem[q_off + 3u]));
|
||||
q_shmem[q_off + 0u],
|
||||
q_shmem[q_off + 1u],
|
||||
q_shmem[q_off + 2u],
|
||||
q_shmem[q_off + 3u]);
|
||||
#ifdef KV_DIRECT
|
||||
let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
|
||||
let kv = vec4<f32>(K[idx >> 2u]);
|
||||
#else
|
||||
let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
|
||||
let kv = vec4<f32>(
|
||||
f32(kv_shmem[idx + 0u]),
|
||||
f32(kv_shmem[idx + 1u]),
|
||||
f32(kv_shmem[idx + 2u]),
|
||||
f32(kv_shmem[idx + 3u]));
|
||||
kv_shmem[idx + 0u],
|
||||
kv_shmem[idx + 1u],
|
||||
kv_shmem[idx + 2u],
|
||||
kv_shmem[idx + 3u]);
|
||||
#endif
|
||||
partial_sum += dot(qv, kv);
|
||||
}
|
||||
|
|
@ -435,7 +447,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
|
||||
let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
|
||||
if (tx == 0u && kv_valid) {
|
||||
inter_shmem[kv_idx] = f16(sum_bcast);
|
||||
inter_shmem[kv_idx] = sum_bcast;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -450,7 +462,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let global_k_col = kv_tile + elem_idx;
|
||||
let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv;
|
||||
let mask_idx = mask_global_offset + global_k_col;
|
||||
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
|
||||
mask_shmem[elem_idx] = select(0.0f, f32(mask[mask_idx]), mask_in_bounds);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
|
@ -483,7 +495,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
|
||||
total_exp_term += subgroupAdd(cur_p);
|
||||
if (kv_idx < KV_TILE) {
|
||||
inter_shmem[kv_idx] = f16(cur_p);
|
||||
inter_shmem[kv_idx] = cur_p;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -493,7 +505,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
exp_sum = exp_sum * cur_exp + total_exp_term;
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp);
|
||||
o_shmem[elem_idx] = o_shmem[elem_idx] * cur_exp;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -517,8 +529,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d);
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d);
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
|
|
@ -545,7 +557,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let q_val = f32(q_byte) * f32(d);
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
|
|
@ -563,10 +575,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
|
||||
let vec_idx = (global_v_row_offset + v_col) >> 2u;
|
||||
let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f16(v4.x);
|
||||
kv_shmem[elem_idx + 1u] = f16(v4.y);
|
||||
kv_shmem[elem_idx + 2u] = f16(v4.z);
|
||||
kv_shmem[elem_idx + 3u] = f16(v4.w);
|
||||
kv_shmem[elem_idx + 0u] = f32(v4.x);
|
||||
kv_shmem[elem_idx + 1u] = f32(v4.y);
|
||||
kv_shmem[elem_idx + 2u] = f32(v4.z);
|
||||
kv_shmem[elem_idx + 3u] = f32(v4.w);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -589,17 +601,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
continue;
|
||||
}
|
||||
|
||||
let p = f32(inter_shmem[kv_idx]);
|
||||
let p = inter_shmem[kv_idx];
|
||||
#ifdef KV_DIRECT
|
||||
let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
|
||||
let v4 = vec4<f32>(V[v_idx >> 2u]);
|
||||
#else
|
||||
let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
|
||||
let v4 = vec4<f32>(
|
||||
f32(kv_shmem[v_idx + 0u]),
|
||||
f32(kv_shmem[v_idx + 1u]),
|
||||
f32(kv_shmem[v_idx + 2u]),
|
||||
f32(kv_shmem[v_idx + 3u]));
|
||||
kv_shmem[v_idx + 0u],
|
||||
kv_shmem[v_idx + 1u],
|
||||
kv_shmem[v_idx + 2u],
|
||||
kv_shmem[v_idx + 3u]);
|
||||
#endif
|
||||
lo += p * v4;
|
||||
}
|
||||
|
|
@ -630,10 +642,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
|
||||
if (ty_pv == 0u) {
|
||||
let elem_base = vec_col * 4u;
|
||||
o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x);
|
||||
o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y);
|
||||
o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z);
|
||||
o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w);
|
||||
o_shmem[elem_base + 0u] = o_shmem[elem_base + 0u] + lo_x;
|
||||
o_shmem[elem_base + 1u] = o_shmem[elem_base + 1u] + lo_y;
|
||||
o_shmem[elem_base + 2u] = o_shmem[elem_base + 2u] + lo_z;
|
||||
o_shmem[elem_base + 3u] = o_shmem[elem_base + 3u] + lo_w;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -660,7 +672,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
exp_sum = exp_sum * max_exp + sink_exp_sum;
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp);
|
||||
o_shmem[elem_idx] = o_shmem[elem_idx] * max_exp;
|
||||
}
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
|
@ -681,7 +693,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
);
|
||||
|
||||
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
|
||||
dst[dst_vec_index] = v;
|
||||
dst[dst_vec_index] = vec4<DST_TYPE>(v);
|
||||
}
|
||||
} else {
|
||||
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start;
|
||||
|
|
|
|||
|
|
@ -50,10 +50,25 @@ struct Params {
|
|||
@group(0) @binding(PARAMS_BINDING)
|
||||
var<uniform> params: Params;
|
||||
|
||||
fn erf_approx(x: TYPE) -> TYPE {
|
||||
let x_f32 = f32(x);
|
||||
let s = select(-1.0, 1.0, x_f32 >= 0.0);
|
||||
let ax = abs(x_f32);
|
||||
|
||||
let t = 1.0 / (1.0 + 0.3275911 * ax);
|
||||
|
||||
let y = 1.0 -
|
||||
(((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t
|
||||
- 0.284496736) * t + 0.254829592) * t) *
|
||||
exp(-ax * ax);
|
||||
|
||||
return TYPE(s * y);
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
return;
|
||||
return;
|
||||
}
|
||||
var i = gid.x;
|
||||
let ne2 = params.ne2;
|
||||
|
|
@ -71,15 +86,13 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let i1 = i / ne0;
|
||||
let i0 = i % ne0;
|
||||
|
||||
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
|
||||
i2 * params.stride_src2 + i3 * params.stride_src3;
|
||||
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3;
|
||||
|
||||
#ifdef ABS
|
||||
let res = abs(src[params.offset_src + src_idx]);
|
||||
#endif
|
||||
#ifdef SGN
|
||||
let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
|
||||
src[params.offset_src + src_idx] > 0.0);
|
||||
let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), src[params.offset_src + src_idx] > 0.0);
|
||||
#endif
|
||||
#ifdef NEG
|
||||
let res = -src[params.offset_src + src_idx];
|
||||
|
|
@ -94,8 +107,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
|
||||
#endif
|
||||
#ifdef ELU
|
||||
let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
|
||||
src[params.offset_src + src_idx] > 0.0);
|
||||
let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
|
||||
#endif
|
||||
#ifdef HARDSIGMOID
|
||||
let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
|
||||
|
|
@ -120,31 +132,16 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let res = TYPE(params.fill_val);
|
||||
#endif
|
||||
#ifdef HARDSWISH
|
||||
let res = src[params.offset_src + src_idx] *
|
||||
min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
|
||||
let res = src[params.offset_src + src_idx] * min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
|
||||
#endif
|
||||
#ifdef GELU
|
||||
let res = 0.5 * src[params.offset_src + src_idx] *
|
||||
(1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
|
||||
(src[params.offset_src + src_idx] +
|
||||
0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
|
||||
-9.010913, 9.010913)));
|
||||
let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(clamp(0.7978845608028654 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), -9.010913, 9.010913)));
|
||||
#endif
|
||||
#ifdef GELU_QUICK
|
||||
let res = src[params.offset_src + src_idx] * 0.5 *
|
||||
(1.0 + tanh(clamp(0.79788456 *
|
||||
(src[params.offset_src + src_idx] +
|
||||
0.044715 * src[params.offset_src + src_idx] *
|
||||
src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
|
||||
-9.010913, 9.010913)));
|
||||
let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(clamp(-1.702 * src[params.offset_src + src_idx], -80.0, 80.0))));
|
||||
#endif
|
||||
#ifdef GELU_ERF
|
||||
let res = 0.5 * src[params.offset_src + src_idx] *
|
||||
(1.0 + tanh(clamp(0.79788456 *
|
||||
(src[params.offset_src + src_idx] +
|
||||
0.044715 * src[params.offset_src + src_idx] *
|
||||
src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
|
||||
-9.010913, 9.010913)));
|
||||
let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.7071067811865476));
|
||||
#endif
|
||||
#ifdef XIELU
|
||||
let val = f32(src[params.offset_src + src_idx]);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue