ggml-webgpu: address precision issues for multimodal (#22808)

* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32

* fix(unary): correct the gelu, gelu quick and gelu erf functions

* fix(flash-attn-tile): fix the hardcoded V type

* fix(flash_attn): fix tile path

* fix: pass editorconfig and address the type conflicts

* fix: remove redundant pipeline keys

* fix: remove inline min/max group size functions and revert the flash attn path order

* fix: use clamp to avoid NaN for GELU

* fix: use the right range for exp; 80 is safer for f32 exp
This commit is contained in:
Chen Yuan 2026-05-12 10:27:04 -04:00 committed by GitHub
parent 89730c8d26
commit 239a497e5f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 295 additions and 186 deletions

View file

@ -91,6 +91,7 @@ struct ggml_webgpu_shader_lib_context {
uint32_t sg_mat_m = 0;
uint32_t sg_mat_n = 0;
uint32_t sg_mat_k = 0;
uint32_t min_subgroup_size = 0;
uint32_t max_subgroup_size = 0;
};
@ -531,7 +532,9 @@ enum ggml_webgpu_flash_attn_path : uint32_t {
};
struct ggml_webgpu_flash_attn_pipeline_key {
ggml_type q_type;
ggml_type kv_type;
ggml_type dst_type;
uint32_t head_dim_qk;
uint32_t head_dim_v;
bool kv_direct;
@ -542,16 +545,19 @@ struct ggml_webgpu_flash_attn_pipeline_key {
uint32_t path;
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path;
return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap && path == other.path;
}
};
struct ggml_webgpu_flash_attn_pipeline_key_hash {
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.q_type);
ggml_webgpu_hash_combine(seed, key.kv_type);
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
ggml_webgpu_hash_combine(seed, key.head_dim_v);
ggml_webgpu_hash_combine(seed, key.kv_direct);
@ -595,14 +601,14 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_
}
inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
const ggml_webgpu_shader_lib_context & context,
uint32_t path) {
const ggml_webgpu_shader_lib_context & context,
const ggml_webgpu_flash_attn_decisions & decisions) {
const bool has_mask = context.src3 != nullptr;
const bool has_sinks = context.src4 != nullptr;
bool kv_direct = false;
if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH;
if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
kv_direct_align = context.sg_mat_k;
}
kv_direct = (context.src1->type == GGML_TYPE_F16) &&
@ -611,7 +617,9 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
}
ggml_webgpu_flash_attn_pipeline_key key = {};
key.q_type = context.src0->type;
key.kv_type = context.src1->type;
key.dst_type = context.dst->type;
key.head_dim_qk = (uint32_t) context.src0->ne[0];
key.head_dim_v = (uint32_t) context.src2->ne[0];
key.kv_direct = kv_direct;
@ -619,13 +627,14 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
key.has_mask = has_mask;
key.has_sinks = has_sinks;
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
key.path = path;
key.path = decisions.path;
return key;
}
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
uint32_t head_dim_v;
uint32_t wg_size;
uint32_t head_dim_v;
uint32_t wg_size;
ggml_type dst_type;
};
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
@ -633,13 +642,14 @@ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.head_dim_v);
ggml_webgpu_hash_combine(seed, key.wg_size);
ggml_webgpu_hash_combine(seed, key.dst_type);
return seed;
}
};
inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type;
}
struct ggml_webgpu_flash_attn_blk_pipeline_key {
@ -662,19 +672,32 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t head_dim_qk,
uint32_t head_dim_v,
bool has_mask,
bool kv_direct) {
bool kv_direct,
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
size_t f16_elems = 0;
size_t f32_elems = 0;
f16_elems += q_tile * head_dim_qk; // q_shmem
if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
f32_elems += head_dim_qk; // q_shmem
if (!kv_direct) {
f32_elems += kv_tile * max_head_dim; // kv_shmem
}
f32_elems += head_dim_v; // o_shmem
if (has_mask) {
f32_elems += kv_tile; // mask_shmem
}
f32_elems += kv_tile; // inter_shmem
return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
}
f32_elems += q_tile * head_dim_qk; // q_shmem
if (!kv_direct) {
f16_elems += kv_tile * max_head_dim; // kv_shmem
f32_elems += kv_tile * max_head_dim; // kv_shmem
}
f16_elems += q_tile * head_dim_v; // o_shmem
f32_elems += q_tile * head_dim_v; // o_shmem
if (has_mask) {
f16_elems += q_tile * kv_tile; // mask_shmem
f32_elems += q_tile * kv_tile; // mask_shmem
}
f16_elems += q_tile * kv_tile; // inter_shmem
f32_elems += q_tile * kv_tile; // inter_shmem
f32_elems += q_tile; // row_max_shmem
f32_elems += q_tile; // exp_sum_shmem
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
@ -684,27 +707,27 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
const ggml_webgpu_flash_attn_pipeline_key & key) {
const size_t limit_bytes = context.wg_mem_limit_bytes;
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_granularity = context.sg_mat_n;
uint32_t kv_granularity = std::max(1u, context.sg_mat_n);
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
kv_granularity = std::max(1u, context.max_subgroup_size);
kv_granularity = 1u;
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
q_tile = 1u;
kv_granularity = 8u;
}
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
if (!key.kv_direct) {
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v,
key.has_mask, key.kv_direct, key.path);
if (limit_bytes <= base_q_bytes) {
return 0;
}
if (key.has_mask) {
bytes_per_kv += q_tile;
const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v,
key.has_mask, key.kv_direct, key.path);
const size_t bytes_per_kv = one_kv_bytes - base_q_bytes;
if (bytes_per_kv == 0) {
return 0;
}
bytes_per_kv += q_tile;
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
return (max_kv_tile / kv_granularity) * kv_granularity;
const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity);
}
inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
@ -731,14 +754,18 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
(context.src2->type == K->type);
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
(context.src2->type == K->type);
const bool tile_can_dispatch_all_q_rows =
context.max_subgroup_size > 0 &&
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec;
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
tile_can_dispatch_all_q_rows && !use_vec;
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
@ -749,7 +776,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
return decisions;
}
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions);
decisions.kv_direct = key.kv_direct;
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
// invalidate if even the smallest kv_tile doesn't fit in shared memory
@ -778,21 +805,20 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
std::min(64u, max_kv_tile) :
std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE :
std::min(std::max(1u, context.max_wg_size),
std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) :
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size);
decisions.kv_tile =
std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity);
if (decisions.kv_tile == 0) {
return decisions;
}
if (decisions.kv_direct) {
GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
std::max(1u, context.max_subgroup_size) :
context.sg_mat_n;
decisions.kv_tile -=
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n;
}
}
return decisions;
@ -1577,7 +1603,7 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
ggml_webgpu_tensor_overlap(context.src1, context.src5);
auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
@ -1694,10 +1720,10 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
1 :
0;
auto it = mul_mat_vec_pipelines.find(key);
if (it != mul_mat_vec_pipelines.end()) {
@ -1805,13 +1831,13 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_mul_mat_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_subgroup_matrix = context.supports_subgroup_matrix;
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_subgroup_matrix = context.supports_subgroup_matrix;
auto it = mul_mat_fast_pipelines.find(key);
if (it != mul_mat_fast_pipelines.end()) {
@ -2074,10 +2100,10 @@ class ggml_webgpu_shader_lib {
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.n_experts = context.src0->ne[2];
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
1 :
0;
auto it = mul_mat_id_pipelines.find(key);
if (it != mul_mat_id_pipelines.end()) {
@ -2194,10 +2220,10 @@ class ggml_webgpu_shader_lib {
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.n_experts = context.src0->ne[2];
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
1 :
0;
auto it = mul_mat_id_vec_pipelines.find(key);
if (it != mul_mat_id_vec_pipelines.end()) {
@ -2558,7 +2584,7 @@ class ggml_webgpu_shader_lib {
const ggml_webgpu_flash_attn_decisions decisions =
ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment);
GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE);
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions);
auto it = flash_attn_pipelines.find(key);
if (it != flash_attn_pipelines.end()) {
return it->second;
@ -2586,6 +2612,30 @@ class ggml_webgpu_shader_lib {
}
variant += std::string("_") + ggml_type_name(key.kv_type);
switch (key.q_type) {
case GGML_TYPE_F32:
defines.push_back("Q_F32");
break;
case GGML_TYPE_F16:
defines.push_back("Q_F16");
break;
default:
GGML_ABORT("Unsupported Q type for flash attention shader");
}
variant += std::string("_q") + ggml_type_name(key.q_type);
switch (key.dst_type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
break;
default:
GGML_ABORT("Unsupported dst type for flash attention shader");
}
variant += std::string("_dst") + ggml_type_name(key.dst_type);
if (key.has_mask) {
defines.push_back("MASK");
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
@ -2625,9 +2675,11 @@ class ggml_webgpu_shader_lib {
shader_src = wgsl_flash_attn_vec_split;
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
shader_src = wgsl_flash_attn_tile;
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size));
defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u");
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v)));
variant += "_tile";
variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" +
std::to_string(context.max_subgroup_size);
} else {
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
@ -2677,6 +2729,7 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {};
key.head_dim_v = (uint32_t) context.src2->ne[0];
key.dst_type = context.dst->type;
key.wg_size = context.max_wg_size;
auto it = flash_attn_vec_reduce_pipelines.find(key);
if (it != flash_attn_vec_reduce_pipelines.end()) {
@ -2686,6 +2739,18 @@ class ggml_webgpu_shader_lib {
std::vector<std::string> defines;
std::string variant = "flash_attn_vec_reduce";
switch (key.dst_type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
break;
default:
GGML_ABORT("Unsupported dst type for flash attention vec reduce shader");
}
variant += std::string("_dst") + ggml_type_name(key.dst_type);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
variant += std::string("_hsv") + std::to_string(key.head_dim_v);

View file

@ -187,6 +187,7 @@ struct webgpu_capabilities {
uint32_t sg_mat_k = 0;
uint32_t subgroup_size = 0;
uint32_t min_subgroup_size = 0;
uint32_t max_subgroup_size = 0;
size_t memset_bytes_per_thread;
};
@ -1442,6 +1443,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
// Get or create pipeline
@ -1750,6 +1752,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
@ -3469,6 +3472,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
@ -3667,8 +3671,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
#endif
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
// Runtime subgroup size can be any supported size in this range. Shaders
// that allocate per-lane register arrays must size them for the minimum.
ctx->webgpu_global_ctx->capabilities.min_subgroup_size = info.subgroupMinSize;
ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
// Initialize device
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
@ -4024,11 +4029,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.max_wg_size =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
@ -4040,9 +4048,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
decisions.kv_direct, decisions.path);
if (min_bytes > limit_bytes) {
supports_op = false;
}
@ -4050,9 +4058,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
}
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
decisions.kv_direct, decisions.path);
if (min_bytes > limit_bytes) {
supports_op = false;
}
@ -4063,9 +4071,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
supports_op = false;
break;
}
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
decisions.kv_direct, decisions.path);
if (min_bytes > limit_bytes) {
supports_op = false;
}

View file

@ -1,12 +1,33 @@
enable f16;
enable subgroups;
#ifdef Q_F16
#define Q_TYPE f16
#else
#define Q_TYPE f32
#endif
#ifdef KV_F32
#define KV_TYPE f32
#else
#define KV_TYPE f16
#endif
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64
#define KV_STAGE_STRIDE 64
#define Q_TILE 4
#define KV_TILE 64
#define WG_SIZE 128
#ifndef MIN_SUBGROUP_SIZE
#define MIN_SUBGROUP_SIZE MAX_SUBGROUP_SIZE
#endif
struct Params {
offset_q: u32,
@ -41,13 +62,13 @@ struct Params {
m1: f32,
};
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
#ifdef KV_OVERLAP
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
#define V K
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
@group(0) @binding(2) var<storage, read_write> V: array<vec4<f16>>;
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
#endif
#if defined(MASK) && defined(SINKS)
@ -92,17 +113,17 @@ struct Params {
#endif
#endif
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
const FLOAT_MIN: f32 = -1.0e9;
const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
var<workgroup> kv_shmem: array<f16, KV_TILE * KV_STAGE_STRIDE>;
var<workgroup> q_shmem: array<f32, Q_TILE * HEAD_DIM_QK>;
var<workgroup> kv_shmem: array<f32, KV_TILE * KV_STAGE_STRIDE>;
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
@compute @workgroup_size(WG_SIZE)
@ -158,10 +179,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let q_col = elem_idx % HEAD_DIM_QK;
let head_q_row = q_row_start + q_tile_row;
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
q_shmem[elem_idx] = f16(select(
q_shmem[elem_idx] = select(
0.0,
Q[global_q_row_offset + q_col] * params.scale,
head_q_row < params.seq_len_q));
f32(Q[global_q_row_offset + q_col]) * params.scale,
head_q_row < params.seq_len_q);
}
workgroupBarrier();
@ -192,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
let k4 = K[k_vec_index];
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
kv_shmem[kv_off + 0u] = k4.x;
kv_shmem[kv_off + 1u] = k4.y;
kv_shmem[kv_off + 2u] = k4.z;
kv_shmem[kv_off + 3u] = k4.w;
kv_shmem[kv_off + 0u] = f32(k4.x);
kv_shmem[kv_off + 1u] = f32(k4.y);
kv_shmem[kv_off + 2u] = f32(k4.z);
kv_shmem[kv_off + 3u] = f32(k4.w);
}
workgroupBarrier();
@ -213,16 +234,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
let q_off = q_base + chunk * 4u;
let qv = vec4<f32>(
f32(q_shmem[q_off + 0u]),
f32(q_shmem[q_off + 1u]),
f32(q_shmem[q_off + 2u]),
f32(q_shmem[q_off + 3u]));
q_shmem[q_off + 0u],
q_shmem[q_off + 1u],
q_shmem[q_off + 2u],
q_shmem[q_off + 3u]);
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
let kv = vec4<f32>(
f32(kv_shmem[kv_off + 0u]),
f32(kv_shmem[kv_off + 1u]),
f32(kv_shmem[kv_off + 2u]),
f32(kv_shmem[kv_off + 3u]));
kv_shmem[kv_off + 0u],
kv_shmem[kv_off + 1u],
kv_shmem[kv_off + 2u],
kv_shmem[kv_off + 3u]);
dot_val += dot(qv, kv);
}
#ifdef LOGIT_SOFTCAP
@ -264,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
let v4 = V[v_vec_index];
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
kv_shmem[kv_off + 0u] = v4.x;
kv_shmem[kv_off + 1u] = v4.y;
kv_shmem[kv_off + 2u] = v4.z;
kv_shmem[kv_off + 3u] = v4.w;
kv_shmem[kv_off + 0u] = f32(v4.x);
kv_shmem[kv_off + 1u] = f32(v4.y);
kv_shmem[kv_off + 2u] = f32(v4.z);
kv_shmem[kv_off + 3u] = f32(v4.w);
}
workgroupBarrier();
@ -288,10 +309,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let p = p_shmem[subgroup_p_offset + kv_local];
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
let v4 = vec4<f32>(
f32(kv_shmem[kv_off + 0u]),
f32(kv_shmem[kv_off + 1u]),
f32(kv_shmem[kv_off + 2u]),
f32(kv_shmem[kv_off + 3u]));
kv_shmem[kv_off + 0u],
kv_shmem[kv_off + 1u],
kv_shmem[kv_off + 2u],
kv_shmem[kv_off + 3u]);
acc += p * v4;
}
out_regs[reg_idx] = acc;
@ -324,7 +345,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
continue;
}
let dst_vec_index = (row_base + chunk * 4u) >> 2u;
dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
dst[dst_vec_index] = vec4<DST_TYPE>(out_regs[reg_idx] * inv_exp_sum);
}
}
}

View file

@ -2,6 +2,12 @@ diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif
// Default values
#define HEAD_DIM_V 64
#define WG_SIZE 128
@ -17,7 +23,7 @@ struct Params {
};
@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
@group(0) @binding(2) var<uniform> params: Params;
const FLOAT_MIN: f32 = -1.0e9;
@ -72,7 +78,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (thread == 0u) {
let dst_vec_index = (row_base + elem_base) >> 2u;
dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
dst[dst_vec_index] = vec4<DST_TYPE>(vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s);
}
}
}

View file

@ -8,6 +8,18 @@ enable subgroups;
#define KV_TYPE f16
#endif
#ifdef Q_F16
#define Q_TYPE f16
#else
#define Q_TYPE f32
#endif
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64
@ -89,7 +101,7 @@ struct Params {
nwg: u32,
};
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
#ifdef KV_OVERLAP
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
@ -191,41 +203,41 @@ struct Params {
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
#endif
@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
// Just a very small float value.
const FLOAT_MIN: f32 = -1.0e9;
var<workgroup> q_shmem: array<f16, HEAD_DIM_QK>;
var<workgroup> q_shmem: array<f32, HEAD_DIM_QK>;
#ifndef KV_DIRECT
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
// we can reuse the same shmem for K and V since we only need one at a time
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
var<workgroup> kv_shmem: array<f32, kv_shmem_size>;
#endif
var<workgroup> o_shmem: array<f16, HEAD_DIM_V>;
var<workgroup> o_shmem: array<f32, HEAD_DIM_V>;
#ifdef MASK
// storage for mask values
var<workgroup> mask_shmem: array<f16, KV_TILE>;
var<workgroup> mask_shmem: array<f32, KV_TILE>;
#endif
// note that we reuse the same storage for both since we only need one at a time
var<workgroup> inter_shmem: array<f16, KV_TILE>;
var<workgroup> inter_shmem: array<f32, KV_TILE>;
// Storage for row max and exp sum during online softmax
fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
var v = select(FLOAT_MIN,
f32(inter_shmem[kv_idx]) * params.scale,
inter_shmem[kv_idx] * params.scale,
kv_idx < KV_TILE);
#ifdef LOGIT_SOFTCAP
v = params.logit_softcap * tanh(v);
#endif
#ifdef MASK
if (apply_mask) {
var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE);
var mask_val = select(0.0, mask_shmem[kv_idx], kv_idx < KV_TILE);
v += select(mask_val, slope * mask_val, has_bias);
}
#endif
@ -289,10 +301,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
// load the single Q row into shared memory
for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) {
let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1;
q_shmem[elem_idx] = f16(select(
q_shmem[elem_idx] = select(
0.0,
Q[global_q_row_offset + elem_idx],
q_row_start < params.seq_len_q));
f32(Q[global_q_row_offset + elem_idx]),
q_row_start < params.seq_len_q);
}
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
@ -308,7 +320,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let blk_state = blk_state_local;
let skip_tile = blk_state == 0u;
for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) {
inter_shmem[elem_idx] = f16(0.0);
inter_shmem[elem_idx] = 0.0;
}
// load k tile into shared memory
@ -331,8 +343,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d);
let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d);
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
@ -359,7 +371,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let q_val = f32(q_byte) * f32(d);
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
@ -377,10 +389,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
let vec_idx = (global_k_row_offset + k_col) >> 2u;
let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
kv_shmem[elem_idx + 0u] = f16(k4.x);
kv_shmem[elem_idx + 1u] = f16(k4.y);
kv_shmem[elem_idx + 2u] = f16(k4.z);
kv_shmem[elem_idx + 3u] = f16(k4.w);
kv_shmem[elem_idx + 0u] = f32(k4.x);
kv_shmem[elem_idx + 1u] = f32(k4.y);
kv_shmem[elem_idx + 2u] = f32(k4.z);
kv_shmem[elem_idx + 3u] = f32(k4.w);
}
#endif
@ -401,20 +413,20 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let q_off = i * 4u;
let qv = vec4<f32>(
f32(q_shmem[q_off + 0u]),
f32(q_shmem[q_off + 1u]),
f32(q_shmem[q_off + 2u]),
f32(q_shmem[q_off + 3u]));
q_shmem[q_off + 0u],
q_shmem[q_off + 1u],
q_shmem[q_off + 2u],
q_shmem[q_off + 3u]);
#ifdef KV_DIRECT
let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
let kv = vec4<f32>(K[idx >> 2u]);
#else
let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
let kv = vec4<f32>(
f32(kv_shmem[idx + 0u]),
f32(kv_shmem[idx + 1u]),
f32(kv_shmem[idx + 2u]),
f32(kv_shmem[idx + 3u]));
kv_shmem[idx + 0u],
kv_shmem[idx + 1u],
kv_shmem[idx + 2u],
kv_shmem[idx + 3u]);
#endif
partial_sum += dot(qv, kv);
}
@ -435,7 +447,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
if (tx == 0u && kv_valid) {
inter_shmem[kv_idx] = f16(sum_bcast);
inter_shmem[kv_idx] = sum_bcast;
}
}
}
@ -450,7 +462,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let global_k_col = kv_tile + elem_idx;
let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv;
let mask_idx = mask_global_offset + global_k_col;
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
mask_shmem[elem_idx] = select(0.0f, f32(mask[mask_idx]), mask_in_bounds);
}
}
#else
@ -483,7 +495,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
total_exp_term += subgroupAdd(cur_p);
if (kv_idx < KV_TILE) {
inter_shmem[kv_idx] = f16(cur_p);
inter_shmem[kv_idx] = cur_p;
}
}
@ -493,7 +505,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
exp_sum = exp_sum * cur_exp + total_exp_term;
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp);
o_shmem[elem_idx] = o_shmem[elem_idx] * cur_exp;
}
}
@ -517,8 +529,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d);
let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d);
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
@ -545,7 +557,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let q_val = f32(q_byte) * f32(d);
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
@ -563,10 +575,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
let vec_idx = (global_v_row_offset + v_col) >> 2u;
let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
kv_shmem[elem_idx + 0u] = f16(v4.x);
kv_shmem[elem_idx + 1u] = f16(v4.y);
kv_shmem[elem_idx + 2u] = f16(v4.z);
kv_shmem[elem_idx + 3u] = f16(v4.w);
kv_shmem[elem_idx + 0u] = f32(v4.x);
kv_shmem[elem_idx + 1u] = f32(v4.y);
kv_shmem[elem_idx + 2u] = f32(v4.z);
kv_shmem[elem_idx + 3u] = f32(v4.w);
}
#endif
@ -589,17 +601,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
continue;
}
let p = f32(inter_shmem[kv_idx]);
let p = inter_shmem[kv_idx];
#ifdef KV_DIRECT
let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
let v4 = vec4<f32>(V[v_idx >> 2u]);
#else
let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
let v4 = vec4<f32>(
f32(kv_shmem[v_idx + 0u]),
f32(kv_shmem[v_idx + 1u]),
f32(kv_shmem[v_idx + 2u]),
f32(kv_shmem[v_idx + 3u]));
kv_shmem[v_idx + 0u],
kv_shmem[v_idx + 1u],
kv_shmem[v_idx + 2u],
kv_shmem[v_idx + 3u]);
#endif
lo += p * v4;
}
@ -630,10 +642,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (ty_pv == 0u) {
let elem_base = vec_col * 4u;
o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x);
o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y);
o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z);
o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w);
o_shmem[elem_base + 0u] = o_shmem[elem_base + 0u] + lo_x;
o_shmem[elem_base + 1u] = o_shmem[elem_base + 1u] + lo_y;
o_shmem[elem_base + 2u] = o_shmem[elem_base + 2u] + lo_z;
o_shmem[elem_base + 3u] = o_shmem[elem_base + 3u] + lo_w;
}
}
}
@ -660,7 +672,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
exp_sum = exp_sum * max_exp + sink_exp_sum;
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp);
o_shmem[elem_idx] = o_shmem[elem_idx] * max_exp;
}
}
workgroupBarrier();
@ -681,7 +693,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
);
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
dst[dst_vec_index] = v;
dst[dst_vec_index] = vec4<DST_TYPE>(v);
}
} else {
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start;

View file

@ -50,10 +50,25 @@ struct Params {
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
// Polynomial approximation of the error function erf(x)
// (Abramowitz & Stegun 7.1.26; max absolute error ~1.5e-7).
// Evaluated in f32 regardless of TYPE to avoid f16 precision loss,
// then cast back to TYPE for the caller.
fn erf_approx(x: TYPE) -> TYPE {
    let xf = f32(x);
    // erf is odd: compute on |x| and restore the sign at the end.
    let sign_x = select(-1.0, 1.0, xf >= 0.0);
    let ax = abs(xf);
    let t = 1.0 / (1.0 + 0.3275911 * ax);
    // Horner evaluation of the degree-5 fit polynomial in t.
    var poly = 1.061405429;
    poly = poly * t - 1.453152027;
    poly = poly * t + 1.421413741;
    poly = poly * t - 0.284496736;
    poly = poly * t + 0.254829592;
    // exp(-ax*ax) <= 1, so no overflow concern here.
    let y = 1.0 - (poly * t) * exp(-ax * ax);
    return TYPE(sign_x * y);
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
return;
}
var i = gid.x;
let ne2 = params.ne2;
@ -71,15 +86,13 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i1 = i / ne0;
let i0 = i % ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3;
#ifdef ABS
let res = abs(src[params.offset_src + src_idx]);
#endif
#ifdef SGN
let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
src[params.offset_src + src_idx] > 0.0);
let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef NEG
let res = -src[params.offset_src + src_idx];
@ -94,8 +107,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef ELU
let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
src[params.offset_src + src_idx] > 0.0);
let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef HARDSIGMOID
let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
@ -120,31 +132,16 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let res = TYPE(params.fill_val);
#endif
#ifdef HARDSWISH
let res = src[params.offset_src + src_idx] *
min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
let res = src[params.offset_src + src_idx] * min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
#endif
#ifdef GELU
let res = 0.5 * src[params.offset_src + src_idx] *
(1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
(src[params.offset_src + src_idx] +
0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
-9.010913, 9.010913)));
let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(clamp(0.7978845608028654 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), -9.010913, 9.010913)));
#endif
#ifdef GELU_QUICK
let res = src[params.offset_src + src_idx] * 0.5 *
(1.0 + tanh(clamp(0.79788456 *
(src[params.offset_src + src_idx] +
0.044715 * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
-9.010913, 9.010913)));
let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(clamp(-1.702 * src[params.offset_src + src_idx], -80.0, 80.0))));
#endif
#ifdef GELU_ERF
let res = 0.5 * src[params.offset_src + src_idx] *
(1.0 + tanh(clamp(0.79788456 *
(src[params.offset_src + src_idx] +
0.044715 * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
-9.010913, 9.010913)));
let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.7071067811865476));
#endif
#ifdef XIELU
let val = f32(src[params.offset_src + src_idx]);