ggml-webgpu: support for SSM_SCAN and disable set_rows error checking (#22327)
* Implement ssm_scan
* Remove blocking in graph_compute and check for set rows
* Fix bindings
* Update op support
parent 0adede866d
commit dd2914dc81
5 changed files with 4168 additions and 1669 deletions
@@ -26,7 +26,7 @@ Legend:
 | CLAMP | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
 | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
 | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
-| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
 | CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
 | CONV_3D | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
@@ -60,7 +60,7 @@ Legend:
 | GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
 | HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 | HARDSWISH | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
-| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 | IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
 | L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
 | LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
@@ -105,7 +105,7 @@ Legend:
 | SQR | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
 | SQRT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
 | SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
-| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
+| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
 | STEP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 | SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 | SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
docs/ops/WebGPU.csv: 5501 lines changed; file diff suppressed because it is too large.
@@ -98,6 +98,29 @@ struct ggml_webgpu_ssm_conv_shader_decisions {
     uint32_t tokens_per_wg;
 };
 
+struct ggml_webgpu_ssm_scan_pipeline_key {
+    int type;
+    int d_state;
+
+    bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
+        return type == other.type && d_state == other.d_state;
+    }
+};
+
+struct ggml_webgpu_ssm_scan_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.d_state);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_ssm_scan_shader_decisions {
+    uint32_t wg_size;
+    uint32_t tokens_per_tile;
+};
+
 /** Argsort **/
 
 struct ggml_webgpu_argsort_shader_lib_context {
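The key/hash pair mirrors the existing ssm_conv cache plumbing: pipelines are memoized per (type, d_state) in an std::unordered_map, so the hash must fold both fields. ggml_webgpu_hash_combine itself is defined outside this diff; a minimal sketch, assuming the usual boost-style combiner (the name hash_combine_sketch is illustrative):

    #include <cstddef>
    #include <functional>

    // Assumed shape of ggml_webgpu_hash_combine (not shown in this diff):
    // mix each field into a running seed so that field order matters.
    template <typename T> static void hash_combine_sketch(size_t & seed, const T & v) {
        seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }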
@@ -921,6 +944,8 @@ class ggml_webgpu_shader_lib {
         solve_tri_pipelines;  // type
     std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
         ssm_conv_pipelines;  // type/vectorized
+    std::unordered_map<ggml_webgpu_ssm_scan_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_scan_pipeline_key_hash>
+        ssm_scan_pipelines;  // type/d_state
     std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
                        webgpu_pipeline,
                        ggml_webgpu_gated_delta_net_pipeline_key_hash>
@@ -1433,6 +1458,53 @@ class ggml_webgpu_shader_lib {
         return ssm_conv_pipelines[key];
     }
 
+    webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_ssm_scan_pipeline_key key = {};
+        key.type    = context.dst->type;
+        key.d_state = (int) context.src0->ne[0];
+
+        auto it = ssm_scan_pipelines.find(key);
+        if (it != ssm_scan_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "ssm_scan";
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                variant += "_f32";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for ssm_scan shader");
+        }
+
+        const uint32_t wg_size = (uint32_t) key.d_state;
+
+        constexpr uint32_t tokens_per_tile = 4u;
+
+        defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u");
+        defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u");
+
+        if (context.supports_subgroups) {
+            defines.push_back("USE_SUBGROUP_REDUCTION");
+            variant += "_sg_reduce";
+        } else {
+            variant += "_wg_reduce";
+        }
+
+        variant += "_d" + std::to_string(key.d_state);
+
+        auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
+        auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
+        decisions->wg_size         = wg_size;
+        decisions->tokens_per_tile = tokens_per_tile;
+        webgpu_pipeline pipeline   = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context           = decisions;
+        ssm_scan_pipelines[key]    = pipeline;
+        return ssm_scan_pipelines[key];
+    }
+
     webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_gated_delta_net_pipeline_key key = {};
         key.type = context.dst->type;
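Each distinct (type, d_state) pair thus gets its own compiled variant: WG_SIZE and TOKENS_PER_TILE are baked in as preprocessor defines rather than passed as uniforms, which lets the shader size its workgroup arrays statically. As a worked example (values chosen for illustration), an F32 scan with d_state = 128 on a device with subgroup support is built with WG_SIZE=128u, TOKENS_PER_TILE=4u, and USE_SUBGROUP_REDUCTION, under the variant name ssm_scan_f32_sg_reduce_d128.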
@@ -1115,6 +1115,80 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
 }
 
+static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
+                                              ggml_tensor *    src0,
+                                              ggml_tensor *    src1,
+                                              ggml_tensor *    src2,
+                                              ggml_tensor *    src3,
+                                              ggml_tensor *    src4,
+                                              ggml_tensor *    src5,
+                                              ggml_tensor *    src6,
+                                              ggml_tensor *    dst) {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0               = src0;
+    shader_lib_ctx.dst                = dst;
+    shader_lib_ctx.max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+
+        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+
+        (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
+        (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
+
+        (uint32_t) src3->ne[0],
+        (uint32_t) (src3->nb[1] / ggml_type_size(src3->type)),
+
+        (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
+        (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
+        (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
+
+        (uint32_t) (src5->nb[1] / ggml_type_size(src5->type)),
+        (uint32_t) (src5->nb[2] / ggml_type_size(src5->type)),
+        (uint32_t) (src5->nb[3] / ggml_type_size(src5->type)),
+
+        (uint32_t) src0->ne[0],
+        (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2],
+        (uint32_t) src4->ne[1],
+        (uint32_t) src1->ne[2],
+        (uint32_t) src1->ne[3],
+        (uint32_t) ggml_nelements(src1),
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst),
+    };
+
+    const uint32_t total_wg       = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]);
+    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+    uint32_t       wg_x;
+    uint32_t       wg_y;
+    compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
+
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
+}
+
 static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
                                                      ggml_tensor *    src0,
                                                      ggml_tensor *    src1,
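compute_2d_workgroups is an existing helper not shown in this diff; the shader recovers the flat index as wg_linear = wg_id.y * num_wg.x + wg_id.x, which pins down the contract the helper must satisfy. A minimal sketch of that contract, assuming the simplest fill-x-then-spill-to-y split (the rounding strategy is an assumption, not the actual implementation):

    #include <cstdint>

    // Assumed behavior: spread `total` workgroups over a 2D dispatch so that
    // neither dimension exceeds maxComputeWorkgroupsPerDimension.
    static void compute_2d_workgroups_sketch(uint32_t total, uint32_t max_per_dim,
                                             uint32_t & wg_x, uint32_t & wg_y) {
        if (total <= max_per_dim) {
            wg_x = total;  // common case: a 1D dispatch suffices
            wg_y = 1;
        } else {
            wg_x = max_per_dim;
            wg_y = (total + max_per_dim - 1) / max_per_dim;  // round up
            // wg_x * wg_y can exceed `total`; the real helper may instead pick
            // exact factors so no out-of-range wg_linear is ever decoded.
        }
    }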
@@ -2764,6 +2838,9 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
             return ggml_webgpu_solve_tri(ctx, src0, src1, node);
         case GGML_OP_SSM_CONV:
             return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
+        case GGML_OP_SSM_SCAN:
+            return ggml_webgpu_ssm_scan(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node->src[6],
+                                        node);
         case GGML_OP_GATED_DELTA_NET:
             return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
         case GGML_OP_PAD:
@@ -2822,7 +2899,10 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context &
     }
 #endif
 
+// Don't bother checking set_rows index overflow for now; in practice, the WebGPU backend doesn't currently need to
+// support models that would require it.
 static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) {
+#ifdef GGML_WEBGPU_CHECK_SET_ROWS
     wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
     encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
                                ctx->set_rows_host_error_buf.GetSize());
@@ -2835,6 +2915,10 @@ static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t &
         GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
     }
     ctx->set_rows_host_error_buf.Unmap();
+#else
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(num_inflight_batches);
+#endif
 }
 
 static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
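With this guard, the error-buffer readback and the GGML_ABORT above only compile when GGML_WEBGPU_CHECK_SET_ROWS is defined; by default the function is a no-op, matching the rationale in the comment. To re-enable the check, define the macro when compiling the backend, e.g. pass -DGGML_WEBGPU_CHECK_SET_ROWS to the compiler (whether the build system exposes a dedicated switch for this is not shown in this diff).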
@@ -2920,8 +3004,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
         ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches);
     }
 
-    ggml_backend_webgpu_wait_queue(ctx->global_ctx);
-
     WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
     return GGML_STATUS_SUCCESS;
 }
@@ -3941,6 +4023,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_SSM_CONV:
             supports_op = op->type == GGML_TYPE_F32;
             break;
+        case GGML_OP_SSM_SCAN:
+            supports_op = op->type == GGML_TYPE_F32 &&
+                          src0->ne[0] <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+            break;
         case GGML_OP_GATED_DELTA_NET:
             {
                 const uint32_t s_v = (uint32_t) src2->ne[0];
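The added bound mirrors get_ssm_scan_pipeline above, which sets the workgroup size to d_state (src0->ne[0]), i.e. one invocation per state column; a d_state larger than maxComputeInvocationsPerWorkgroup simply could not be launched. For scale: the WebGPU default for that limit is 256 invocations, which comfortably admits the d_state = 128 typical of Mamba-2-style models.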
ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl (new file, 168 lines)
@@ -0,0 +1,168 @@
+#ifdef USE_SUBGROUP_REDUCTION
+enable subgroups;
+#endif
+
+struct Params {
+    offset_s: u32,
+    offset_x: u32,
+    offset_dt: u32,
+    offset_A: u32,
+    offset_B: u32,
+    offset_C: u32,
+    offset_ids: u32,
+    offset_dst: u32,
+
+    stride_s1: u32,
+    stride_s2: u32,
+    stride_s3: u32,
+
+    stride_x1: u32,
+    stride_x2: u32,
+    stride_x3: u32,
+
+    stride_dt1: u32,
+    stride_dt2: u32,
+
+    a_ne0: u32,
+    stride_A1: u32,
+
+    stride_B1: u32,
+    stride_B2: u32,
+    stride_B3: u32,
+
+    stride_C1: u32,
+    stride_C2: u32,
+    stride_C3: u32,
+
+    d_state: u32,
+    d_inner: u32,
+    n_head: u32,
+    n_group: u32,
+    n_seq_tokens: u32,
+    n_seqs: u32,
+
+    y_elems: u32,
+};
+
+@group(0) @binding(0) var<storage, read_write> s_in: array<f32>;
+@group(0) @binding(1) var<storage, read_write> x: array<f32>;
+@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
+@group(0) @binding(3) var<storage, read_write> A: array<f32>;
+@group(0) @binding(4) var<storage, read_write> B: array<f32>;
+@group(0) @binding(5) var<storage, read_write> C: array<f32>;
+@group(0) @binding(6) var<storage, read_write> ids: array<i32>;
+@group(0) @binding(7) var<storage, read_write> dst: array<f32>;
+@group(0) @binding(8) var<uniform> params: Params;
+
+var<workgroup> shared_x_dt: array<f32, TOKENS_PER_TILE>;
+var<workgroup> shared_dtsp: array<f32, TOKENS_PER_TILE>;
+var<workgroup> shared_reduce: array<f32, TOKENS_PER_TILE * WG_SIZE>;
+
+fn reduce_base(token_in_tile: u32) -> u32 {
+    return token_in_tile * WG_SIZE;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(workgroup_id) wg_id: vec3<u32>,
+    @builtin(num_workgroups) num_wg: vec3<u32>
+#ifdef USE_SUBGROUP_REDUCTION
+    , @builtin(subgroup_id) subgroup_id: u32,
+    @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
+    @builtin(num_subgroups) num_subgroups: u32
+#endif
+) {
+    let tid       = local_id.x;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+
+    let i1       = wg_linear % params.d_inner;
+    let head_seq = wg_linear / params.d_inner;
+    let ir       = head_seq % params.n_head;
+    let i3       = head_seq / params.n_head;
+
+    let state_slot = u32(ids[params.offset_ids + i3]);
+    let g          = ir / (params.n_head / params.n_group);
+
+    let s_idx  = params.offset_s + tid + i1 * params.stride_s1 + ir * params.stride_s2 + state_slot * params.stride_s3;
+    var s_prev = s_in[s_idx];
+
+    let A0 = A[params.offset_A + (tid % params.a_ne0) + ir * params.stride_A1];
+
+    for (var token_base = 0u; token_base < params.n_seq_tokens; token_base += TOKENS_PER_TILE) {
+        if (tid < TOKENS_PER_TILE) {
+            let token = token_base + tid;
+            if (token < params.n_seq_tokens) {
+                let x_idx  = params.offset_x + i1 + ir * params.stride_x1 + token * params.stride_x2 + i3 * params.stride_x3;
+                let dt_idx = params.offset_dt + ir + token * params.stride_dt1 + i3 * params.stride_dt2;
+                let dt0    = dt[dt_idx];
+                let dtsp   = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0);
+                shared_dtsp[tid] = dtsp;
+                shared_x_dt[tid] = x[x_idx] * dtsp;
+            }
+        }
+
+        workgroupBarrier();
+
+        for (var token_in_tile = 0u; token_in_tile < TOKENS_PER_TILE; token_in_tile++) {
+            let token = token_base + token_in_tile;
+            if (token >= params.n_seq_tokens) {
+                break;
+            }
+
+            let x_dt       = shared_x_dt[token_in_tile];
+            let dA         = exp(shared_dtsp[token_in_tile] * A0);
+            let reduce_idx = reduce_base(token_in_tile) + tid;
+
+            let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3;
+            let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3;
+            let s     = s_prev * dA + B[b_idx] * x_dt;
+            s_prev    = s;
+
+#ifdef USE_SUBGROUP_REDUCTION
+            let subgroup_partial = subgroupAdd(s * C[c_idx]);
+            if (subgroup_invocation_id == 0u) {
+                shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial;
+            }
+#else
+            shared_reduce[reduce_idx] = s * C[c_idx];
+#endif
+
+            workgroupBarrier();
+
+#ifdef USE_SUBGROUP_REDUCTION
+            if (tid == 0u) {
+                var sum = 0.0;
+                for (var sg = 0u; sg < num_subgroups; sg++) {
+                    sum += shared_reduce[reduce_base(token_in_tile) + sg];
+                }
+                let y_idx =
+                    params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) +
+                    i3 * (params.n_seq_tokens * params.n_head * params.d_inner);
+                dst[y_idx] = sum;
+            }
+#else
+            for (var stride = WG_SIZE / 2u; stride > 0u; stride >>= 1u) {
+                if (tid < stride) {
+                    shared_reduce[reduce_idx] += shared_reduce[reduce_idx + stride];
+                }
+                workgroupBarrier();
+            }
+
+            if (tid == 0u) {
+                let y_idx =
+                    params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) +
+                    i3 * (params.n_seq_tokens * params.n_head * params.d_inner);
+                dst[y_idx] = shared_reduce[reduce_base(token_in_tile)];
+            }
+#endif
+
+            workgroupBarrier();
+        }
+    }
+
+    let state_idx =
+        params.offset_dst + params.y_elems + tid + i1 * params.d_state + ir * (params.d_state * params.d_inner) +
+        i3 * (params.d_state * params.d_inner * params.n_head);
+    dst[state_idx] = s_prev;
+}
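For reference, the recurrence a single workgroup evaluates, written out sequentially: each workgroup owns one (i1, ir, i3) triple, carries d_state running state values across the token loop, and emits one y value per token plus the final state. A minimal scalar sketch, assuming a_ne0 == 1 (one decay coefficient per head, as in Mamba-2); all names here are illustrative, not backend API:

    #include <cmath>
    #include <vector>

    // Scalar model of one workgroup's work for a fixed (i1, ir, i3).
    static void ssm_scan_ref(std::vector<float> &       s,   // [d_state] running state
                             const std::vector<float> & x,   // [n_tok] input per token
                             const std::vector<float> & dt,  // [n_tok] raw time step
                             float                      A,   // per-head decay coefficient
                             const std::vector<float> & B,   // [n_tok * d_state]
                             const std::vector<float> & C,   // [n_tok * d_state]
                             std::vector<float> &       y) { // [n_tok] output
        const size_t d_state = s.size();
        for (size_t t = 0; t < x.size(); ++t) {
            // softplus(dt), with the same large-input shortcut as the shader
            const float dtsp = dt[t] > 20.0f ? dt[t] : std::log(1.0f + std::exp(dt[t]));
            const float x_dt = x[t] * dtsp;
            const float dA   = std::exp(dtsp * A);
            float       sum  = 0.0f;
            for (size_t j = 0; j < d_state; ++j) {
                s[j] = s[j] * dA + B[t * d_state + j] * x_dt;  // state update
                sum += s[j] * C[t * d_state + j];              // readout
            }
            y[t] = sum;
        }
    }

The shader parallelizes the inner j loop across WG_SIZE invocations (the sum becomes the subgroup or workgroup reduction), stages dtsp and x_dt for TOKENS_PER_TILE tokens at a time in workgroup memory, and finally writes the state back into dst past the y_elems output region.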