sycl: add FILL, CUMSUM, DIAG, SOLVE_TRI, SSM_SCAN, GATED_DELTA_NET (#22149)

* sycl: add FILL, CUMSUM, DIAG, SOLVE_TRI, SSM_SCAN, GATED_DELTA_NET

Signed-off-by: Chun Tao <chun.tao@intel.com>

* Fix abort during test-backend-ops

Signed-off-by: Todd Malsbary <todd.malsbary@intel.com>

* Regenerate ops.md

Signed-off-by: Todd Malsbary <todd.malsbary@intel.com>

* Add scope_dbg_print to newly added SYCL ops.

Also add scope_dbg_print to existing ssm_conv op.

Signed-off-by: Todd Malsbary <todd.malsbary@intel.com>

---------

Signed-off-by: Chun Tao <chun.tao@intel.com>
Signed-off-by: Todd Malsbary <todd.malsbary@intel.com>
Co-authored-by: Chun Tao <chun.tao@intel.com>
Co-authored-by: Todd Malsbary <todd.malsbary@intel.com>
This commit is contained in:
Intel AI Get-to Market Customer Success and Solutions 2026-05-07 08:51:33 -07:00 committed by GitHub
parent b9afc19cb4
commit ad09224658
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 6871 additions and 4113 deletions

View file

@ -17,7 +17,7 @@ Legend:
| ABS | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
@ -36,15 +36,15 @@ Legend:
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ✅ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ✅ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ✅ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
@ -101,11 +101,11 @@ Legend:
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | | ✅ | ✅ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | 🟡 | ✅ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,148 @@
#include "cumsum.hpp"
#include "common.hpp"
#include <algorithm>
#define SYCL_CUMSUM_BLOCK_SIZE 256
// Inclusive prefix sum of `x` over the sub-group that the calling work-item
// belongs to. Every work-item of the sub-group must reach this call, since it
// is implemented with a SYCL group algorithm.
static __dpct_inline__ float warp_prefix_inclusive_sum_f32(float x, const sycl::nd_item<3> & item) {
    const auto              sub_group = item.get_sub_group();
    const sycl::plus<float> add_op{};
    return sycl::inclusive_scan_over_group(sub_group, x, add_op);
}
// Inclusive cumulative sum along dimension 0 of an F32 tensor.
//
// One work-group processes one row; the row is selected by the group indices
// (group(2) -> i1, group(1) -> i2, group(0) -> i3). The row is consumed in chunks
// of `num_unroll * block_size` elements, and a running carry links the chunks.
//
// Parameters:
//   src, dst      - base pointers of the source / destination data
//   ne00..ne03    - extents of the source tensor
//   s01..s03      - source strides for dims 1..3, in float elements
//   d1..d3        - destination strides for dims 1..3, in float elements
//   item          - SYCL nd_item of the calling work-item
//   smem          - scratch buffer of at least block_size + warps_per_block + 2
//                   floats (the launcher in this file passes a sycl::local_accessor)
//
// Elements inside a row are addressed with unit stride (src_row[idx] / dst_row[idx]).
static void cumsum_f32_kernel(
const float * __restrict__ src, float * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t d1, const int64_t d2, const int64_t d3,
const sycl::nd_item<3> & item, float * smem) {
const int tid = item.get_local_id(2);
const int block_size = item.get_local_range(2);
const int lane = tid % WARP_SIZE;
const int warp = tid / WARP_SIZE;
const int warps_per_block = block_size / WARP_SIZE;
// Scratch layout: [0, block_size) per-work-item scan values,
// then warps_per_block per-sub-group sums/offsets,
// then s_carry[0] = running carry, s_carry[1] = total of the current chunk.
float * s_vals = smem;
float * s_warp_sums = smem + block_size;
float * s_carry = smem + block_size + warps_per_block;
if (tid == 0) {
s_carry[0] = 0.0f;
}
item.barrier(sycl::access::fence_space::local_space);
const int64_t i3 = item.get_group(0);
const int64_t i2 = item.get_group(1);
const int64_t i1 = item.get_group(2);
// The group indices are identical for all work-items of a group, so this
// early exit is taken by the whole group or by none of it.
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3;
// Each work-item handles num_unroll consecutive elements per chunk.
constexpr int num_unroll = 4;
float temp[num_unroll];
for (int64_t i = 0; i < ne00; i += num_unroll * block_size) {
int64_t idx = i + tid * num_unroll;
// temp[j] = inclusive prefix sum over this work-item's own elements
// (out-of-range elements contribute nothing).
temp[0] = (idx < ne00 ? src_row[idx] : 0.0f);
#pragma unroll
for (int j = 1; j < num_unroll; j++) {
temp[j] = temp[j - 1];
if (idx + j < ne00) {
temp[j] += src_row[idx + j];
}
}
// Scan the per-work-item totals inside the sub-group.
float val = (idx < ne00) ? temp[num_unroll - 1] : 0.0f;
val = warp_prefix_inclusive_sum_f32(val, item);
s_vals[tid] = val;
// The last lane holds the sum of the whole sub-group.
if (lane == WARP_SIZE - 1) {
s_warp_sums[warp] = val;
}
item.barrier(sycl::access::fence_space::local_space);
// Sub-group 0 scans the per-sub-group sums and turns them into exclusive
// offsets (inc - w); the inclusive value at index warps_per_block - 1 is
// the total of this chunk.
if (warp == 0) {
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
float inc = warp_prefix_inclusive_sum_f32(w, item);
if (tid < warps_per_block) {
s_warp_sums[tid] = inc - w;
}
if (tid == warps_per_block - 1) {
s_carry[1] = inc;
}
}
item.barrier(sycl::access::fence_space::local_space);
float carry = s_carry[0];
// Sum of everything that precedes this work-item's first element:
// inclusive sub-group scan minus own total, plus the sub-group offset,
// plus the carry from previous chunks.
float final_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
#pragma unroll
for (int j = 0; j < num_unroll; j++) {
if (idx + j < ne00) {
dst_row[idx + j] = temp[j] + final_offset;
}
}
// All work-items must have read s_carry[0] before it is advanced.
item.barrier(sycl::access::fence_space::local_space);
if (tid == 0) {
s_carry[0] += s_carry[1];
}
}
}
// Host-side launcher for the F32 cumulative-sum kernel.
//
// Reads the single source tensor from dst->src[0], converts the byte strides of
// dims 1..3 into float-element strides and launches one work-group per row
// (grid = ne03 x ne02 x ne01). The work-group size is ne00 rounded up to a
// multiple of WARP_SIZE and capped at SYCL_CUMSUM_BLOCK_SIZE.
inline void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    // Tensor extents.
    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    // Byte strides -> element strides.
    const size_t  elem_size = sizeof(float);
    const int64_t s01       = src0->nb[1] / elem_size;
    const int64_t s02       = src0->nb[2] / elem_size;
    const int64_t s03       = src0->nb[3] / elem_size;
    const int64_t d1        = dst->nb[1] / elem_size;
    const int64_t d2        = dst->nb[2] / elem_size;
    const int64_t d3        = dst->nb[3] / elem_size;

    // Work-group size: whole sub-groups covering the row, capped.
    const int sub_groups_needed = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
    const int block_size        = std::min(sub_groups_needed * WARP_SIZE, SYCL_CUMSUM_BLOCK_SIZE);

    // Scratch: one float per work-item, one per sub-group, plus two carry slots.
    const int warps_per_block = block_size / WARP_SIZE;
    const int smem_size       = block_size + warps_per_block + 2;

    const sycl::range<3> grid(ne03, ne02, ne01);
    const sycl::range<3> block(1, 1, block_size);
    const sycl::nd_range<3> launch_range(grid * block, block);

    const float * src_d = static_cast<const float *>(src0->data);
    float *       dst_d = static_cast<float *>(dst->data);

    stream->submit([&](sycl::handler & cgh) {
        sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh);

        cgh.parallel_for(
            launch_range,
            [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                cumsum_f32_kernel(src_d, dst_d,
                                  ne00, ne01, ne02, ne03,
                                  s01, s02, s03,
                                  d1, d2, d3,
                                  item, get_pointer(smem_acc));
            });
    });
}
// Public entry point for GGML_OP_CUMSUM on the SYCL backend.
// Keeps a scope_op_debug_print object alive for the duration of the op
// (one source tensor) and forwards to the launcher.
void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/1);

    ggml_sycl_op_cumsum(ctx, dst);
}

View file

@ -0,0 +1,5 @@
#pragma once
#include "common.hpp"
void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

View file

@ -0,0 +1,67 @@
#include "diag.hpp"
#include "common.hpp"
#define SYCL_DIAG_BLOCK_SIZE 256
// Writes one element of a diagonal matrix per work-item.
//
// The flat global id is decomposed into (i0, i1, i2, i3) coordinates of a
// contiguous ne0 x ne1 x ne2 x ne3 destination. Elements with i0 == i1 receive
// src[batch * ne0 + i0] (batch = i3 * ne2 + i2); all other elements are zeroed.
// Work-items whose id is >= total_elements do nothing.
template <typename T>
static void diag_kernel(T * __restrict__ dst, const T * __restrict__ src,
                        const int64_t ne0, const int64_t ne1,
                        const int64_t ne2, const int64_t ne3,
                        const int64_t total_elements,
                        const sycl::nd_item<1> & item) {
    (void) ne3;  // the outermost extent is not needed for the index math

    const int64_t flat = item.get_global_id(0);
    if (flat < total_elements) {
        const int64_t plane = ne0 * ne1;

        const int64_t col = flat % ne0;
        const int64_t row = (flat / ne0) % ne1;
        const int64_t b2  = (flat / plane) % ne2;
        const int64_t b3  = flat / (plane * ne2);

        const int64_t batch   = b3 * ne2 + b2;
        const int64_t out_idx = (batch * ne1 + row) * ne0 + col;

        if (col != row) {
            dst[out_idx] = T(0);
        } else {
            dst[out_idx] = src[batch * ne0 + col];
        }
    }
}
// Host-side launcher for GGML_OP_DIAG.
//
// Requires contiguous source and destination tensors, a source with ne[1] == 1
// and an F32 destination. Launches one work-item per destination element,
// rounded up to a multiple of SYCL_DIAG_BLOCK_SIZE.
inline void ggml_sycl_op_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(src0->ne[1] == 1);

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    // Destination extents drive the index decomposition inside the kernel.
    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];

    const int64_t n_elems    = ggml_nelements(dst);
    const int64_t num_blocks = (n_elems + SYCL_DIAG_BLOCK_SIZE - 1) / SYCL_DIAG_BLOCK_SIZE;

    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const void * src0_d = src0->data;
    void *       dst_d  = dst->data;

    const sycl::range<1> local_range(SYCL_DIAG_BLOCK_SIZE);
    const sycl::range<1> global_range(num_blocks * SYCL_DIAG_BLOCK_SIZE);

    stream->parallel_for(
        sycl::nd_range<1>(global_range, local_range),
        [=](sycl::nd_item<1> item) {
            float *       out = static_cast<float *>(dst_d);
            const float * in  = static_cast<const float *>(src0_d);
            diag_kernel(out, in, ne0, ne1, ne2, ne3, n_elems, item);
        });
}
// Public entry point for GGML_OP_DIAG on the SYCL backend.
// Keeps a scope_op_debug_print object alive for the duration of the op
// (one source tensor) and forwards to the launcher.
void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/1);

    ggml_sycl_op_diag(ctx, dst);
}

View file

@ -0,0 +1,5 @@
#pragma once
#include "common.hpp"
void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

View file

@ -0,0 +1,55 @@
#include "fill.hpp"
#include "common.hpp"
#define SYCL_FILL_BLOCK_SIZE 256
// Stores `value` into dst[i] for the work-item with global id i, provided
// i < k. Work-items past the end of the buffer do nothing.
template <typename T>
static void fill_kernel(T * dst, const int64_t k, const T value,
                        const sycl::nd_item<1> & item) {
    const int64_t gid = (int64_t) item.get_global_id(0);
    if (gid < k) {
        dst[gid] = value;
    }
}
// Host-side launcher for GGML_OP_FILL.
//
// The fill value is read as a float from the start of dst->op_params. The
// destination must be contiguous and of type F32 or F16 (for F16 the value is
// converted to sycl::half on the host); any other type aborts.
inline void ggml_sycl_op_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(ggml_is_contiguous(dst));

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    float fill_value;
    memcpy(&fill_value, dst->op_params, sizeof(float));

    // One work-item per element, rounded up to whole work-groups.
    const int64_t n_total  = ggml_nelements(dst);
    const int64_t n_groups = (n_total + SYCL_FILL_BLOCK_SIZE - 1) / SYCL_FILL_BLOCK_SIZE;

    const sycl::nd_range<1> launch_range(sycl::range<1>(n_groups * SYCL_FILL_BLOCK_SIZE),
                                         sycl::range<1>(SYCL_FILL_BLOCK_SIZE));

    void * dst_d = dst->data;

    if (dst->type == GGML_TYPE_F32) {
        stream->parallel_for(launch_range, [=](sycl::nd_item<1> item) {
            fill_kernel(static_cast<float *>(dst_d), n_total, fill_value, item);
        });
    } else if (dst->type == GGML_TYPE_F16) {
        const sycl::half fill_value_f16 = sycl::half(fill_value);
        stream->parallel_for(launch_range, [=](sycl::nd_item<1> item) {
            fill_kernel(static_cast<sycl::half *>(dst_d), n_total, fill_value_f16, item);
        });
    } else {
        GGML_ABORT("unsupported type");
    }
}
// Public entry point for GGML_OP_FILL on the SYCL backend.
// Keeps a scope_op_debug_print object alive for the duration of the op
// (no source tensors) and forwards to the launcher.
void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/0);

    ggml_sycl_op_fill(ctx, dst);
}

View file

@ -0,0 +1,5 @@
#pragma once
#include "common.hpp"
void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

View file

@ -5,4 +5,5 @@
#include "common.hpp"
#include "ggml.h"
void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

View file

@ -54,7 +54,12 @@
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/ssm_conv.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/ssm_scan.hpp"
#include "ggml-sycl/fill.hpp"
#include "ggml-sycl/cumsum.hpp"
#include "ggml-sycl/diag.hpp"
#include "ggml-sycl/solve_tri.hpp"
#include "ggml-sycl/gated_delta_net.hpp"
static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
@ -4394,6 +4399,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_SSM_CONV:
ggml_sycl_ssm_conv(ctx, dst);
break;
case GGML_OP_SSM_SCAN:
ggml_sycl_ssm_scan(ctx, dst);
break;
case GGML_OP_FILL:
ggml_sycl_fill(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_sycl_cumsum(ctx, dst);
break;
case GGML_OP_DIAG:
ggml_sycl_diag(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_sycl_solve_tri(ctx, dst);
break;
case GGML_OP_ROLL:
ggml_sycl_roll(ctx, dst);
break;
@ -5104,6 +5124,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return op->type == GGML_TYPE_F32;
case GGML_OP_ARANGE:
return op->type == GGML_TYPE_F32;
case GGML_OP_SSM_SCAN:
if (op->src[3]->ne[0] == 1) {
// Mamba2
// (kernel only supports (d_state == 128 || d_state == 256) && d_head % WARP_SIZE == 0)
return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % WARP_SIZE == 0;
} else {
// TODO Mamba-1 not yet ported to SYCL
return false;
}
case GGML_OP_FILL:
case GGML_OP_CUMSUM:
case GGML_OP_DIAG:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= SYCL_SOLVE_TRI_MAX_N && op->src[1]->ne[0] <= SYCL_SOLVE_TRI_MAX_K;
case GGML_OP_FLASH_ATTN_EXT:
return ggml_sycl_flash_attn_ext_supported(device, op);
default:

View file

@ -0,0 +1,172 @@
#include "solve_tri.hpp"
#include "common.hpp"
#include <oneapi/mkl/blas.hpp>
// Small-matrix triangular solve by forward substitution, one sub-group per
// right-hand-side column.
//
// For each batch, solves for X in A * X = B where A is n x n and B, X are
// n x k, all stored row-major. Only entries of A on and to the left of the
// diagonal are read. Supports n <= 2 * WARP_SIZE: each lane keeps row `lane`
// of its column in x_low and row `WARP_SIZE + lane` in x_high.
//
// Template parameters:
//   n_template, k_template - compile-time sizes; 0 selects the runtime
//                            values n_arg / k_arg instead
// Parameters:
//   A, B, X        - base pointers of the three tensors
//   ne02           - extent of dim 2, used to split the batch index
//   nb02/nb03, nb12/nb13, nb2/nb3 - byte strides of dims 2 and 3 of A, B, X
//   item           - SYCL nd_item; local dim 0 selects the column, local dim 1
//                    is the lane, group(1) is the batch
//   sA             - scratch buffer with room for n * n floats (the launcher in
//                    this file passes a sycl::local_accessor)
template <int n_template, int k_template>
static void solve_tri_f32_fast(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const int64_t ne02, [[maybe_unused]] const int64_t ne03,
const int64_t nb02, const int64_t nb03,
const int64_t nb12, const int64_t nb13,
const int64_t nb2, const int64_t nb3,
const int n_arg, const int k_arg,
const sycl::nd_item<2> & item, float * sA) {
const int n = n_template == 0 ? n_arg : n_template;
const int k = k_template == 0 ? k_arg : k_template;
const int batch_idx = item.get_group(1);
const int lane = item.get_local_id(1) % WARP_SIZE;
const int col_idx = item.get_local_id(0);
if (col_idx >= k) {
return;
}
// Split the flat batch index into (i02, i03) and locate this batch's data.
const int64_t i03 = batch_idx / ne02;
const int64_t i02 = batch_idx % ne02;
const float * A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03);
const float * B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3);
// Cooperative copy of A into the scratch buffer; `offset` is the linear
// work-item index inside the work-group.
const int offset = item.get_local_id(1) + item.get_local_id(0) * item.get_local_range(1);
#pragma unroll
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
const int i0 = i + offset;
if (i0 < n * n) {
sA[i0] = A_batch[i0];
}
}
item.barrier(sycl::access::fence_space::local_space);
// Load this column of B: rows [0, WARP_SIZE) into x_low, rows
// [WARP_SIZE, 2 * WARP_SIZE) into x_high (0 for rows >= n).
float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
const int half = WARP_SIZE;
const int nrows_low = (n < half) ? n : half;
// Rows [0, nrows_low): lanes below `row` contribute A[row][lane] * x[lane];
// the products are combined with warp_reduce_sum (defined elsewhere,
// presumably a sub-group sum) and lane `row` finalizes its unknown.
#pragma unroll
for (int row = 0; row < nrows_low; ++row) {
float sum = 0.0f;
if (lane < row) {
sum += sA[row * n + lane] * x_low;
}
sum = warp_reduce_sum<WARP_SIZE>(sum);
if (lane == row) {
x_low = (x_low - sum) / sA[row * n + row];
}
}
// Rows [half, n): every lane contributes its already-solved x_low, plus
// x_high for columns half + lane that lie left of the diagonal.
#pragma unroll
for (int row = half; row < n; ++row) {
float sum = sA[row * n + lane] * x_low;
const int j = half + lane;
if (j < row) {
sum += sA[row * n + j] * x_high;
}
sum = warp_reduce_sum<WARP_SIZE>(sum);
if (lane == row - half) {
x_high = (x_high - sum) / sA[row * n + row];
}
}
// Write the solved column back: rr == 0 stores x_low, rr == 1 stores x_high.
#pragma unroll
for (int rr = 0; rr < 2; ++rr) {
const int row = rr * WARP_SIZE + lane;
if (row < n) {
const float val = (row < half) ? x_low : x_high;
X_batch[row * k + col_idx] = val;
}
}
}
// Batched triangular solve through oneMKL, operating in place on X.
//
// X must already contain the right-hand side; on completion it holds the
// solution. The batch count is ne02 * ne03 and consecutive batches are assumed
// to be nb02 / nb2 bytes apart (the dim-3 strides are not used). A call with a
// batch count of zero is a no-op.
static void solve_tri_f32_mkl(dpct::queue_ptr stream,
                              const float * A, float * X,
                              int n, int k,
                              int64_t ne02, [[maybe_unused]] int64_t ne03,
                              int64_t nb02, [[maybe_unused]] int64_t nb03,
                              int64_t nb2, [[maybe_unused]] int64_t nb3) {
    const int64_t total_batches = ne02 * ne03;
    if (total_batches == 0) {
        return;
    }

    // Batch strides in float elements.
    const int64_t stride_a = nb02 / sizeof(float);
    const int64_t stride_x = nb2 / sizeof(float);

    const float alpha = 1.0f;

    const auto side  = oneapi::mkl::side::right;
    const auto uplo  = oneapi::mkl::uplo::upper;
    const auto trans = oneapi::mkl::transpose::nontrans;
    const auto diag  = oneapi::mkl::diag::nonunit;

    oneapi::mkl::blas::trsm_batch(
        *stream,
        side, uplo, trans, diag,
        k, n, alpha,
        A, n, stride_a,
        X, k, stride_x,
        total_batches);
}
// Host-side dispatcher for GGML_OP_SOLVE_TRI.
//
// dst->src[0] is the n x n coefficient matrix A, dst->src[1] the n x k
// right-hand side B; the solution is written to dst. Both sources must be
// contiguous, A must be F32, and n / k must not exceed
// SYCL_SOLVE_TRI_MAX_N / SYCL_SOLVE_TRI_MAX_K.
//
// Two code paths:
//   * n <= 2 * WARP_SIZE and k <= 32: the sub-group forward-substitution
//     kernel, which reads B directly and writes every element of dst;
//   * otherwise: oneMKL trsm_batch, which solves in place and therefore needs
//     B copied into dst first.
//
// NOTE(review): src1 and dst are accessed as contiguous F32 buffers without
// being asserted as such here — presumably guaranteed by graph construction.
inline void ggml_sycl_op_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(src0->type == GGML_TYPE_F32);

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const int     n    = src0->ne[0];
    const int     k    = src1->ne[0];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    GGML_ASSERT(n <= SYCL_SOLVE_TRI_MAX_N && k <= SYCL_SOLVE_TRI_MAX_K);

    const int64_t total_batches = ne02 * ne03;

    // Nothing to solve; also avoids submitting a launch with a zero-sized range.
    if (n == 0 || k == 0 || total_batches == 0) {
        return;
    }

    const float * A_d = static_cast<const float *>(src0->data);
    const float * B_d = static_cast<const float *>(src1->data);
    float *       X_d = static_cast<float *>(dst->data);

    const int64_t nb02 = src0->nb[2];
    const int64_t nb03 = src0->nb[3];
    const int64_t nb12 = src1->nb[2];
    const int64_t nb13 = src1->nb[3];
    const int64_t nb2  = dst->nb[2];
    const int64_t nb3  = dst->nb[3];

    if (n <= 2 * WARP_SIZE && k <= 32) {
        // Scratch for A: sized for the largest matrix this path accepts.
        const int smem_size = 2 * WARP_SIZE * 2 * WARP_SIZE;

        // One work-group per batch; k sub-groups of WARP_SIZE lanes each.
        const sycl::range<2> grid(1, total_batches);
        const sycl::range<2> block(k, WARP_SIZE);

        stream->submit([&](sycl::handler & cgh) {
            sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh);

            cgh.parallel_for(
                sycl::nd_range<2>(grid * block, block),
                [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    solve_tri_f32_fast<0, 0>(A_d, B_d, X_d, ne02, ne03,
                                             nb02, nb03, nb12, nb13, nb2, nb3,
                                             n, k, item, get_pointer(smem_acc));
                });
        });
    } else {
        // trsm_batch overwrites its right-hand side, so seed dst with B.
        // Only this path needs the copy: the kernel above never reads dst.
        if (X_d != B_d) {
            const int64_t total_elements = (int64_t) n * k * ne02 * ne03;
            stream->memcpy(X_d, B_d, total_elements * sizeof(float));
        }
        solve_tri_f32_mkl(stream, A_d, X_d, n, k, ne02, ne03, nb02, nb03, nb2, nb3);
    }
}
// Public entry point for GGML_OP_SOLVE_TRI on the SYCL backend.
// Keeps a scope_op_debug_print object alive for the duration of the op
// (two source tensors) and forwards to the dispatcher.
void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/2);

    ggml_sycl_op_solve_tri(ctx, dst);
}

View file

@ -0,0 +1,8 @@
#pragma once
#include "common.hpp"
#define SYCL_SOLVE_TRI_MAX_N 64
#define SYCL_SOLVE_TRI_MAX_K 64
void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

View file

@ -63,7 +63,7 @@ static void kernel_ssm_conv(
});
}
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
inline void ggml_sycl_op_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
@ -125,3 +125,8 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
throw;
}
}
// Public entry point for GGML_OP_SSM_CONV on the SYCL backend.
// Keeps a scope_op_debug_print object alive for the duration of the op
// (two source tensors) and forwards to ggml_sycl_op_ssm_conv.
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/2);

    ggml_sycl_op_ssm_conv(ctx, dst);
}

View file

@ -0,0 +1,156 @@
#include "ssm_scan.hpp"
#include "common.hpp"
// Selective-state-space scan where one sub-group owns one (head, channel) pair
// and its d_state-wide state vector.
//
// Template parameters:
//   c_factor - number of state elements kept per lane (state[c_factor]); the
//              launcher in this file instantiates it as d_state / WARP_SIZE and
//              also uses that many sub-groups per work-group
//   d_state  - length of the state vector
// Parameters:
//   src0 - initial states        src1 - x          src2 - dt
//   src3 - A                     src4 - B          src5 - C
//   src6 - per-sequence index used to select the initial state inside src0
//   dst  - outputs y, followed at byte offset s_off by the final states
//   srcN_nbM - byte stride of dimension M of source N
//   n_head, d_head, n_group, n_tok - head count, channels per head, group
//                                    count and tokens per sequence
//   item - SYCL nd_item; group(0) is the sequence, group(1) the block of
//          sub-groups
template <int c_factor, int d_state>
static void ssm_scan_f32_group(
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok,
const sycl::nd_item<2> & item) {
const int lane = item.get_local_id(1) % WARP_SIZE;
const int warp = item.get_local_id(1) / WARP_SIZE;
// Global sub-group index = flat (head, channel) index.
const int warp_idx = item.get_group(1) * c_factor + warp;
const int seq_idx = item.get_group(0);
const int head_idx = warp_idx / d_head;
// head_off and group_off are byte offsets.
const int head_off = (warp_idx % d_head) * sizeof(float);
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
// All base pointers below are formed with byte arithmetic and then viewed as float.
const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
// y is addressed in elements; the state output lives s_off bytes into dst.
float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
// Per-token strides, converted from bytes to float elements.
const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = n_head * d_head;
// Lane `lane` holds state elements lane, lane + WARP_SIZE, lane + 2 * WARP_SIZE, ...
float state[c_factor];
float state_sum = 0.0f;
#pragma unroll
for (int j = 0; j < c_factor; j++) {
state[j] = s0_warp[WARP_SIZE * j + lane];
}
for (int64_t i = 0; i < n_tok; i++) {
const float dt_val = dt_warp[i * stride_dt];
// softplus(dt); for dt > 20 the input is passed through unchanged.
const float dt_soft_plus = (dt_val <= 20.0f ? sycl::log1p(sycl::exp(dt_val)) : dt_val);
state_sum = 0.0f;
const float dA = sycl::exp(dt_soft_plus * A_warp[0]);
const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
#pragma unroll
for (int j = 0; j < c_factor; j++) {
const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
// state <- state * dA + B * (x * dt); accumulate this lane's part of y = <state, C>.
state[j] = (state[j] * dA) + (B_val * x_dt);
state_sum += state[j] * C_val;
}
// Combine the per-lane partial sums (warp_reduce_sum is defined elsewhere,
// presumably a sub-group sum); lane 0 stores the output for token i.
state_sum = warp_reduce_sum<WARP_SIZE>(state_sum);
if (lane == 0) {
y_warp[i * stride_y] = state_sum;
}
}
// Store the final state after the last token.
#pragma unroll
for (int j = 0; j < c_factor; j++) {
s_warp[WARP_SIZE * j + lane] = state[j];
}
}
// Selects and launches the ssm_scan_f32_group instantiation matching d_state.
//
// Only d_state == 128 and d_state == 256 are implemented; any other value
// aborts. The work-group size equals d_state, the grid is
// n_seq x ceil(n_head * head_dim / sub-groups-per-work-group), and src3 is
// required to have a dim-1 byte stride of sizeof(float).
static void ssm_scan_f32_sycl(
    const float * src0, const float * src1, const float * src2, const float * src3,
    const float * src4, const float * src5, const int32_t * src6, float * dst,
    const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
    const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
    const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
    const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
    dpct::queue_ptr stream) {
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    GGML_ASSERT(src3_nb1 == sizeof(float));

    switch (d_state) {
        case 128: {
            constexpr int wg_size   = 128;
            constexpr int sg_per_wg = wg_size / WARP_SIZE;

            const sycl::range<2> block(1, wg_size);
            const sycl::range<2> grid(n_seq, (n_head * head_dim + sg_per_wg - 1) / sg_per_wg);

            stream->parallel_for(
                sycl::nd_range<2>(grid * block, block),
                [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    ssm_scan_f32_group<128 / WARP_SIZE, 128>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                        src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item);
                });
            break;
        }
        case 256: {
            constexpr int wg_size   = 256;
            constexpr int sg_per_wg = wg_size / WARP_SIZE;

            const sycl::range<2> block(1, wg_size);
            const sycl::range<2> grid(n_seq, (n_head * head_dim + sg_per_wg - 1) / sg_per_wg);

            stream->parallel_for(
                sycl::nd_range<2>(grid * block, block),
                [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    ssm_scan_f32_group<256 / WARP_SIZE, 256>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                        src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item);
                });
            break;
        }
        default:
            GGML_ABORT("ssm_scan: unsupported d_state (must be 128 or 256)");
    }
}
// Host-side launcher for GGML_OP_SSM_SCAN.
//
// Collects the seven source tensors from dst, checks the F32 / I32 types and
// that dst holds exactly the y outputs (as many elements as src1) followed by
// the final states (nc * nr * nh * n_s elements), then hands the raw pointers
// and byte strides to ssm_scan_f32_sycl.
inline void ggml_sycl_op_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];
    const ggml_tensor * src3 = dst->src[3];
    const ggml_tensor * src4 = dst->src[4];
    const ggml_tensor * src5 = dst->src[5];
    const ggml_tensor * src6 = dst->src[6];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src6->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // Problem sizes, taken from the shapes of src0, src1 and src4.
    const int64_t nc  = src0->ne[0];
    const int64_t nr  = src0->ne[1];
    const int64_t nh  = src1->ne[1];
    const int64_t ng  = src4->ne[1];
    const int64_t n_t = src1->ne[2];
    const int64_t n_s = src1->ne[3];

    // Byte offset inside dst at which the state output begins.
    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

    GGML_ASSERT(ggml_nelements(src1) + nc * nr * nh * n_s == ggml_nelements(dst));

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const float *   src0_d = static_cast<const float *>(src0->data);
    const float *   src1_d = static_cast<const float *>(src1->data);
    const float *   src2_d = static_cast<const float *>(src2->data);
    const float *   src3_d = static_cast<const float *>(src3->data);
    const float *   src4_d = static_cast<const float *>(src4->data);
    const float *   src5_d = static_cast<const float *>(src5->data);
    const int32_t * src6_d = static_cast<const int32_t *>(src6->data);
    float *         dst_d  = static_cast<float *>(dst->data);

    ssm_scan_f32_sycl(
        src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
        src0->nb[2], src0->nb[3],
        src1->nb[2], src1->nb[3],
        src2->nb[1], src2->nb[2],
        src3->nb[1],
        src4->nb[2], src4->nb[3],
        src5->nb[2], src5->nb[3],
        s_off, nc, nr, nh, ng, n_t, n_s, stream);
}
// Public entry point for GGML_OP_SSM_SCAN on the SYCL backend.
// Keeps a scope_op_debug_print object alive for the duration of the op
// (seven source tensors) and forwards to the launcher.
void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/7);

    ggml_sycl_op_ssm_scan(ctx, dst);
}

View file

@ -0,0 +1,5 @@
#pragma once
#include "common.hpp"
void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst);