mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-19 08:00:25 +00:00
sycl: add FILL, CUMSUM, DIAG, SOLVE_TRI, SSM_SCAN, GATED_DELTA_NET (#22149)
* sycl: add FILL, CUMSUM, DIAG, SOLVE_TRI, SSM_SCAN, GATED_DELTA_NET Signed-off-by: Chun Tao <chun.tao@intel.com> * Fix abort during test-backend-ops Signed-off-by: Todd Malsbary <todd.malsbary@intel.com> * Regenerate ops.md Signed-off-by: Todd Malsbary <todd.malsbary@intel.com> * Add scope_dbg_print to newly added SYCL ops. Also add scope_dbg_print to existing ssm_conv op. Signed-off-by: Todd Malsbary <todd.malsbary@intel.com> --------- Signed-off-by: Chun Tao <chun.tao@intel.com> Signed-off-by: Todd Malsbary <todd.malsbary@intel.com> Co-authored-by: Chun Tao <chun.tao@intel.com> Co-authored-by: Todd Malsbary <todd.malsbary@intel.com>
This commit is contained in:
parent
b9afc19cb4
commit
ad09224658
15 changed files with 6871 additions and 4113 deletions
12
docs/ops.md
12
docs/ops.md
|
|
@ -17,7 +17,7 @@ Legend:
|
|||
| ABS | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
|
@ -36,15 +36,15 @@ Legend:
|
|||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
|
|
@ -101,11 +101,11 @@ Legend:
|
|||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
|
|
|
|||
10301
docs/ops/SYCL.csv
10301
docs/ops/SYCL.csv
File diff suppressed because it is too large
Load diff
148
ggml/src/ggml-sycl/cumsum.cpp
Normal file
148
ggml/src/ggml-sycl/cumsum.cpp
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
#include "cumsum.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#define SYCL_CUMSUM_BLOCK_SIZE 256
|
||||
|
||||
static __dpct_inline__ float warp_prefix_inclusive_sum_f32(float x, const sycl::nd_item<3> & item) {
|
||||
return sycl::inclusive_scan_over_group(item.get_sub_group(), x, sycl::plus<float>());
|
||||
}
|
||||
|
||||
static void cumsum_f32_kernel(
|
||||
const float * __restrict__ src, float * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t d1, const int64_t d2, const int64_t d3,
|
||||
const sycl::nd_item<3> & item, float * smem) {
|
||||
|
||||
const int tid = item.get_local_id(2);
|
||||
const int block_size = item.get_local_range(2);
|
||||
const int lane = tid % WARP_SIZE;
|
||||
const int warp = tid / WARP_SIZE;
|
||||
const int warps_per_block = block_size / WARP_SIZE;
|
||||
|
||||
float * s_vals = smem;
|
||||
float * s_warp_sums = smem + block_size;
|
||||
float * s_carry = smem + block_size + warps_per_block;
|
||||
|
||||
if (tid == 0) {
|
||||
s_carry[0] = 0.0f;
|
||||
}
|
||||
item.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const int64_t i3 = item.get_group(0);
|
||||
const int64_t i2 = item.get_group(1);
|
||||
const int64_t i1 = item.get_group(2);
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3;
|
||||
|
||||
constexpr int num_unroll = 4;
|
||||
float temp[num_unroll];
|
||||
|
||||
for (int64_t i = 0; i < ne00; i += num_unroll * block_size) {
|
||||
int64_t idx = i + tid * num_unroll;
|
||||
|
||||
temp[0] = (idx < ne00 ? src_row[idx] : 0.0f);
|
||||
#pragma unroll
|
||||
for (int j = 1; j < num_unroll; j++) {
|
||||
temp[j] = temp[j - 1];
|
||||
if (idx + j < ne00) {
|
||||
temp[j] += src_row[idx + j];
|
||||
}
|
||||
}
|
||||
|
||||
float val = (idx < ne00) ? temp[num_unroll - 1] : 0.0f;
|
||||
|
||||
val = warp_prefix_inclusive_sum_f32(val, item);
|
||||
s_vals[tid] = val;
|
||||
|
||||
if (lane == WARP_SIZE - 1) {
|
||||
s_warp_sums[warp] = val;
|
||||
}
|
||||
item.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (warp == 0) {
|
||||
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
|
||||
float inc = warp_prefix_inclusive_sum_f32(w, item);
|
||||
if (tid < warps_per_block) {
|
||||
s_warp_sums[tid] = inc - w;
|
||||
}
|
||||
if (tid == warps_per_block - 1) {
|
||||
s_carry[1] = inc;
|
||||
}
|
||||
}
|
||||
item.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
float carry = s_carry[0];
|
||||
float final_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_unroll; j++) {
|
||||
if (idx + j < ne00) {
|
||||
dst_row[idx + j] = temp[j] + final_offset;
|
||||
}
|
||||
}
|
||||
|
||||
item.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (tid == 0) {
|
||||
s_carry[0] += s_carry[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const float * src_d = static_cast<const float *>(src0->data);
|
||||
float * dst_d = static_cast<float *>(dst->data);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const size_t ts = sizeof(float);
|
||||
const int64_t s01 = src0->nb[1] / ts;
|
||||
const int64_t s02 = src0->nb[2] / ts;
|
||||
const int64_t s03 = src0->nb[3] / ts;
|
||||
const int64_t d1 = dst->nb[1] / ts;
|
||||
const int64_t d2 = dst->nb[2] / ts;
|
||||
const int64_t d3 = dst->nb[3] / ts;
|
||||
|
||||
const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
|
||||
int block_size = num_warps * WARP_SIZE;
|
||||
block_size = std::min(block_size, SYCL_CUMSUM_BLOCK_SIZE);
|
||||
const int warps_per_block = block_size / WARP_SIZE;
|
||||
const int smem_size = block_size + warps_per_block + 2;
|
||||
|
||||
const sycl::range<3> grid(ne03, ne02, ne01);
|
||||
const sycl::range<3> block(1, 1, block_size);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh);
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid * block, block),
|
||||
[=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
cumsum_f32_kernel(src_d, dst_d, ne00, ne01, ne02, ne03,
|
||||
s01, s02, s03, d1, d2, d3,
|
||||
item, get_pointer(smem_acc));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_cumsum(ctx, dst);
|
||||
}
|
||||
5
ggml/src/ggml-sycl/cumsum.hpp
Normal file
5
ggml/src/ggml-sycl/cumsum.hpp
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
67
ggml/src/ggml-sycl/diag.cpp
Normal file
67
ggml/src/ggml-sycl/diag.cpp
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
#include "diag.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_DIAG_BLOCK_SIZE 256
|
||||
|
||||
template <typename T>
|
||||
static void diag_kernel(T * __restrict__ dst, const T * __restrict__ src,
|
||||
const int64_t ne0, const int64_t ne1,
|
||||
const int64_t ne2, const int64_t ne3,
|
||||
const int64_t total_elements,
|
||||
const sycl::nd_item<1> & item) {
|
||||
const int64_t i = item.get_global_id(0);
|
||||
if (i >= total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i0 = i % ne0;
|
||||
const int64_t i1 = (i / ne0) % ne1;
|
||||
const int64_t i2 = (i / (ne0 * ne1)) % ne2;
|
||||
const int64_t i3 = i / (ne0 * ne1 * ne2);
|
||||
|
||||
const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;
|
||||
|
||||
if (i0 == i1) {
|
||||
const int64_t batch_idx = i3 * ne2 + i2;
|
||||
dst[dst_idx] = src[batch_idx * ne0 + i0];
|
||||
} else {
|
||||
dst[dst_idx] = T(0);
|
||||
}
|
||||
|
||||
(void)ne3;
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(src0->ne[1] == 1);
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const void * src0_d = src0->data;
|
||||
void * dst_d = dst->data;
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
const int64_t n_elems = ggml_nelements(dst);
|
||||
const int64_t num_blocks = (n_elems + SYCL_DIAG_BLOCK_SIZE - 1) / SYCL_DIAG_BLOCK_SIZE;
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(num_blocks * SYCL_DIAG_BLOCK_SIZE, SYCL_DIAG_BLOCK_SIZE),
|
||||
[=](sycl::nd_item<1> item) {
|
||||
diag_kernel(static_cast<float *>(dst_d),
|
||||
static_cast<const float *>(src0_d),
|
||||
ne0, ne1, ne2, ne3, n_elems, item);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_diag(ctx, dst);
|
||||
}
|
||||
5
ggml/src/ggml-sycl/diag.hpp
Normal file
5
ggml/src/ggml-sycl/diag.hpp
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
55
ggml/src/ggml-sycl/fill.cpp
Normal file
55
ggml/src/ggml-sycl/fill.cpp
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
#include "fill.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_FILL_BLOCK_SIZE 256
|
||||
|
||||
template <typename T>
|
||||
static void fill_kernel(T * dst, const int64_t k, const T value,
|
||||
const sycl::nd_item<1> & item) {
|
||||
const int64_t i = (int64_t)item.get_global_id(0);
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = value;
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
float value;
|
||||
memcpy(&value, dst->op_params, sizeof(float));
|
||||
|
||||
const int64_t k = ggml_nelements(dst);
|
||||
const int64_t num_blocks = (k + SYCL_FILL_BLOCK_SIZE - 1) / SYCL_FILL_BLOCK_SIZE;
|
||||
void * dst_d = dst->data;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32:
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE),
|
||||
[=](sycl::nd_item<1> item) {
|
||||
fill_kernel(static_cast<float *>(dst_d), k, value, item);
|
||||
});
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
sycl::half h_value = sycl::half(value);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE),
|
||||
[=](sycl::nd_item<1> item) {
|
||||
fill_kernel(static_cast<sycl::half *>(dst_d), k, h_value, item);
|
||||
});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
|
||||
ggml_sycl_op_fill(ctx, dst);
|
||||
}
|
||||
5
ggml/src/ggml-sycl/fill.hpp
Normal file
5
ggml/src/ggml-sycl/fill.hpp
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -5,4 +5,5 @@
|
|||
#include "common.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
|||
|
|
@ -54,7 +54,12 @@
|
|||
#include "ggml-sycl/set.hpp"
|
||||
#include "ggml-sycl/ssm_conv.hpp"
|
||||
#include "ggml-sycl/sycl_hw.hpp"
|
||||
|
||||
#include "ggml-sycl/ssm_scan.hpp"
|
||||
#include "ggml-sycl/fill.hpp"
|
||||
#include "ggml-sycl/cumsum.hpp"
|
||||
#include "ggml-sycl/diag.hpp"
|
||||
#include "ggml-sycl/solve_tri.hpp"
|
||||
#include "ggml-sycl/gated_delta_net.hpp"
|
||||
|
||||
static bool g_sycl_loaded = false;
|
||||
int g_ggml_sycl_debug = 0;
|
||||
|
|
@ -4394,6 +4399,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|||
case GGML_OP_SSM_CONV:
|
||||
ggml_sycl_ssm_conv(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SSM_SCAN:
|
||||
ggml_sycl_ssm_scan(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_FILL:
|
||||
ggml_sycl_fill(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CUMSUM:
|
||||
ggml_sycl_cumsum(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DIAG:
|
||||
ggml_sycl_diag(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
ggml_sycl_solve_tri(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROLL:
|
||||
ggml_sycl_roll(ctx, dst);
|
||||
break;
|
||||
|
|
@ -5104,6 +5124,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
return op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ARANGE:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SSM_SCAN:
|
||||
if (op->src[3]->ne[0] == 1) {
|
||||
// Mamba2
|
||||
// (kernel only supports (d_state == 128 || d_state == 256) && d_head % WARP_SIZE == 0)
|
||||
return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % WARP_SIZE == 0;
|
||||
} else {
|
||||
// TODO Mamba-1 not yet ported to SYCL
|
||||
return false;
|
||||
}
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_DIAG:
|
||||
return true;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return op->src[0]->ne[0] <= SYCL_SOLVE_TRI_MAX_N && op->src[1]->ne[0] <= SYCL_SOLVE_TRI_MAX_K;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return ggml_sycl_flash_attn_ext_supported(device, op);
|
||||
default:
|
||||
|
|
|
|||
172
ggml/src/ggml-sycl/solve_tri.cpp
Normal file
172
ggml/src/ggml-sycl/solve_tri.cpp
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
#include "solve_tri.hpp"
|
||||
#include "common.hpp"
|
||||
#include <oneapi/mkl/blas.hpp>
|
||||
|
||||
template <int n_template, int k_template>
|
||||
static void solve_tri_f32_fast(const float * __restrict__ A,
|
||||
const float * __restrict__ B,
|
||||
float * __restrict__ X,
|
||||
const int64_t ne02, [[maybe_unused]] const int64_t ne03,
|
||||
const int64_t nb02, const int64_t nb03,
|
||||
const int64_t nb12, const int64_t nb13,
|
||||
const int64_t nb2, const int64_t nb3,
|
||||
const int n_arg, const int k_arg,
|
||||
const sycl::nd_item<2> & item, float * sA) {
|
||||
|
||||
const int n = n_template == 0 ? n_arg : n_template;
|
||||
const int k = k_template == 0 ? k_arg : k_template;
|
||||
|
||||
const int batch_idx = item.get_group(1);
|
||||
const int lane = item.get_local_id(1) % WARP_SIZE;
|
||||
const int col_idx = item.get_local_id(0);
|
||||
|
||||
if (col_idx >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i03 = batch_idx / ne02;
|
||||
const int64_t i02 = batch_idx % ne02;
|
||||
|
||||
const float * A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03);
|
||||
const float * B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13);
|
||||
float * X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3);
|
||||
|
||||
const int offset = item.get_local_id(1) + item.get_local_id(0) * item.get_local_range(1);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
|
||||
const int i0 = i + offset;
|
||||
if (i0 < n * n) {
|
||||
sA[i0] = A_batch[i0];
|
||||
}
|
||||
}
|
||||
|
||||
item.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
|
||||
float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
|
||||
|
||||
const int half = WARP_SIZE;
|
||||
const int nrows_low = (n < half) ? n : half;
|
||||
|
||||
#pragma unroll
|
||||
for (int row = 0; row < nrows_low; ++row) {
|
||||
float sum = 0.0f;
|
||||
if (lane < row) {
|
||||
sum += sA[row * n + lane] * x_low;
|
||||
}
|
||||
sum = warp_reduce_sum<WARP_SIZE>(sum);
|
||||
if (lane == row) {
|
||||
x_low = (x_low - sum) / sA[row * n + row];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int row = half; row < n; ++row) {
|
||||
float sum = sA[row * n + lane] * x_low;
|
||||
const int j = half + lane;
|
||||
if (j < row) {
|
||||
sum += sA[row * n + j] * x_high;
|
||||
}
|
||||
sum = warp_reduce_sum<WARP_SIZE>(sum);
|
||||
if (lane == row - half) {
|
||||
x_high = (x_high - sum) / sA[row * n + row];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int rr = 0; rr < 2; ++rr) {
|
||||
const int row = rr * WARP_SIZE + lane;
|
||||
if (row < n) {
|
||||
const float val = (row < half) ? x_low : x_high;
|
||||
X_batch[row * k + col_idx] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void solve_tri_f32_mkl(dpct::queue_ptr stream,
|
||||
const float * A, float * X,
|
||||
int n, int k,
|
||||
int64_t ne02, [[maybe_unused]] int64_t ne03,
|
||||
int64_t nb02, [[maybe_unused]] int64_t nb03,
|
||||
int64_t nb2, [[maybe_unused]] int64_t nb3) {
|
||||
const float alpha = 1.0f;
|
||||
const int64_t total_batches = ne02 * ne03;
|
||||
if (total_batches == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t stride_a = nb02 / sizeof(float);
|
||||
const int64_t stride_x = nb2 / sizeof(float);
|
||||
|
||||
oneapi::mkl::blas::trsm_batch(
|
||||
*stream,
|
||||
oneapi::mkl::side::right,
|
||||
oneapi::mkl::uplo::upper,
|
||||
oneapi::mkl::transpose::nontrans,
|
||||
oneapi::mkl::diag::nonunit,
|
||||
k, n, alpha,
|
||||
A, n, stride_a,
|
||||
X, k, stride_x,
|
||||
total_batches);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const int n = src0->ne[0];
|
||||
const int k = src1->ne[0];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
GGML_ASSERT(n <= SYCL_SOLVE_TRI_MAX_N && k <= SYCL_SOLVE_TRI_MAX_K);
|
||||
|
||||
const float * A_d = static_cast<const float *>(src0->data);
|
||||
const float * B_d = static_cast<const float *>(src1->data);
|
||||
float * X_d = static_cast<float *>(dst->data);
|
||||
|
||||
if (X_d != B_d) {
|
||||
const int64_t total_elements = (int64_t)n * k * ne02 * ne03;
|
||||
stream->memcpy(X_d, B_d, total_elements * sizeof(float));
|
||||
}
|
||||
|
||||
const int64_t nb02 = src0->nb[2];
|
||||
const int64_t nb03 = src0->nb[3];
|
||||
const int64_t nb12 = src1->nb[2];
|
||||
const int64_t nb13 = src1->nb[3];
|
||||
const int64_t nb2 = dst->nb[2];
|
||||
const int64_t nb3 = dst->nb[3];
|
||||
|
||||
const int64_t total_batches = ne02 * ne03;
|
||||
|
||||
if (n <= 2 * WARP_SIZE && k <= 32) {
|
||||
const int smem_size = 2 * WARP_SIZE * 2 * WARP_SIZE;
|
||||
const sycl::range<2> grid(1, total_batches);
|
||||
const sycl::range<2> block(k, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh);
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<2>(grid * block, block),
|
||||
[=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
solve_tri_f32_fast<0, 0>(A_d, B_d, X_d, ne02, ne03,
|
||||
nb02, nb03, nb12, nb13, nb2, nb3,
|
||||
n, k, item, get_pointer(smem_acc));
|
||||
});
|
||||
});
|
||||
} else {
|
||||
solve_tri_f32_mkl(stream, A_d, X_d, n, k, ne02, ne03, nb02, nb03, nb2, nb3);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_solve_tri(ctx, dst);
|
||||
}
|
||||
8
ggml/src/ggml-sycl/solve_tri.hpp
Normal file
8
ggml/src/ggml-sycl/solve_tri.hpp
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_SOLVE_TRI_MAX_N 64
|
||||
#define SYCL_SOLVE_TRI_MAX_K 64
|
||||
|
||||
void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -63,7 +63,7 @@ static void kernel_ssm_conv(
|
|||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
inline void ggml_sycl_op_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
|
|
@ -125,3 +125,8 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_ssm_conv(ctx, dst);
|
||||
}
|
||||
|
|
|
|||
156
ggml/src/ggml-sycl/ssm_scan.cpp
Normal file
156
ggml/src/ggml-sycl/ssm_scan.cpp
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
#include "ssm_scan.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
template <int c_factor, int d_state>
|
||||
static void ssm_scan_f32_group(
|
||||
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
|
||||
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
|
||||
const int32_t * __restrict__ src6, float * __restrict__ dst,
|
||||
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
|
||||
const int src2_nb1, const int src2_nb2, const int src3_nb1,
|
||||
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
|
||||
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok,
|
||||
const sycl::nd_item<2> & item) {
|
||||
|
||||
const int lane = item.get_local_id(1) % WARP_SIZE;
|
||||
const int warp = item.get_local_id(1) / WARP_SIZE;
|
||||
const int warp_idx = item.get_group(1) * c_factor + warp;
|
||||
const int seq_idx = item.get_group(0);
|
||||
|
||||
const int head_idx = warp_idx / d_head;
|
||||
const int head_off = (warp_idx % d_head) * sizeof(float);
|
||||
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
|
||||
|
||||
const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
|
||||
const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
|
||||
const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
|
||||
const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
|
||||
const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
|
||||
const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
|
||||
float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
|
||||
float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
|
||||
|
||||
const int stride_x = src1_nb2 / sizeof(float);
|
||||
const int stride_dt = src2_nb1 / sizeof(float);
|
||||
const int stride_B = src4_nb2 / sizeof(float);
|
||||
const int stride_C = src5_nb2 / sizeof(float);
|
||||
const int stride_y = n_head * d_head;
|
||||
|
||||
float state[c_factor];
|
||||
float state_sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < c_factor; j++) {
|
||||
state[j] = s0_warp[WARP_SIZE * j + lane];
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < n_tok; i++) {
|
||||
const float dt_val = dt_warp[i * stride_dt];
|
||||
const float dt_soft_plus = (dt_val <= 20.0f ? sycl::log1p(sycl::exp(dt_val)) : dt_val);
|
||||
|
||||
state_sum = 0.0f;
|
||||
const float dA = sycl::exp(dt_soft_plus * A_warp[0]);
|
||||
const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < c_factor; j++) {
|
||||
const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
|
||||
const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
|
||||
state[j] = (state[j] * dA) + (B_val * x_dt);
|
||||
state_sum += state[j] * C_val;
|
||||
}
|
||||
|
||||
state_sum = warp_reduce_sum<WARP_SIZE>(state_sum);
|
||||
|
||||
if (lane == 0) {
|
||||
y_warp[i * stride_y] = state_sum;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < c_factor; j++) {
|
||||
s_warp[WARP_SIZE * j + lane] = state[j];
|
||||
}
|
||||
}
|
||||
|
||||
static void ssm_scan_f32_sycl(
|
||||
const float * src0, const float * src1, const float * src2, const float * src3,
|
||||
const float * src4, const float * src5, const int32_t * src6, float * dst,
|
||||
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
|
||||
const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
|
||||
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
|
||||
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
|
||||
dpct::queue_ptr stream) {
|
||||
|
||||
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
|
||||
GGML_ASSERT(src3_nb1 == sizeof(float));
|
||||
if (d_state == 128) {
|
||||
constexpr int threads = 128;
|
||||
constexpr int num_warps = threads / WARP_SIZE;
|
||||
const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps);
|
||||
const sycl::range<2> block(1, threads);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<2>(grid * block, block),
|
||||
[=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
ssm_scan_f32_group<128 / WARP_SIZE, 128>(
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
||||
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item);
|
||||
});
|
||||
} else if (d_state == 256) {
|
||||
constexpr int threads = 256;
|
||||
constexpr int num_warps = threads / WARP_SIZE;
|
||||
const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps);
|
||||
const sycl::range<2> block(1, threads);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<2>(grid * block, block),
|
||||
[=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
ssm_scan_f32_group<256 / WARP_SIZE, 256>(
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
||||
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item);
|
||||
});
|
||||
} else {
|
||||
GGML_ABORT("ssm_scan: unsupported d_state (must be 128 or 256)");
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
const ggml_tensor * src3 = dst->src[3];
|
||||
const ggml_tensor * src4 = dst->src[4];
|
||||
const ggml_tensor * src5 = dst->src[5];
|
||||
const ggml_tensor * src6 = dst->src[6];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src6->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t nc = src0->ne[0];
|
||||
const int64_t nr = src0->ne[1];
|
||||
const int64_t nh = src1->ne[1];
|
||||
const int64_t ng = src4->ne[1];
|
||||
const int64_t n_t = src1->ne[2];
|
||||
const int64_t n_s = src1->ne[3];
|
||||
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
|
||||
|
||||
GGML_ASSERT(ggml_nelements(src1) + nc * nr * nh * n_s == ggml_nelements(dst));
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
ssm_scan_f32_sycl(
|
||||
static_cast<const float *>(src0->data), static_cast<const float *>(src1->data),
|
||||
static_cast<const float *>(src2->data), static_cast<const float *>(src3->data),
|
||||
static_cast<const float *>(src4->data), static_cast<const float *>(src5->data),
|
||||
static_cast<const int32_t *>(src6->data), static_cast<float *>(dst->data),
|
||||
src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
|
||||
src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
|
||||
s_off, nc, nr, nh, ng, n_t, n_s, stream);
|
||||
}
|
||||
|
||||
void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7);
|
||||
ggml_sycl_op_ssm_scan(ctx, dst);
|
||||
}
|
||||
5
ggml/src/ggml-sycl/ssm_scan.hpp
Normal file
5
ggml/src/ggml-sycl/ssm_scan.hpp
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
Loading…
Add table
Add a link
Reference in a new issue