mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-22 11:16:08 +00:00
hexagon: ssm-conv fix for large prompts (#23307)
* hexagon: remove gathers and better handling of vtcm in ssm-conv * hexagon: relax ssm-conv gating requirements * hexagon: add new prefill ssm-conv backend test * hexagon: remove trailing whitespace * hex-rope: uninline rope_cache_init, otherwise it breaks after rebasing with SSM_CONV changes --------- Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
parent
ce02093fdd
commit
0be84685bd
4 changed files with 252 additions and 158 deletions
|
|
@ -2735,9 +2735,10 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session *
|
|||
if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contiguous tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
if (src0->nb[0] != sizeof(float) || src1->nb[0] != sizeof(float) || dst->nb[0] != sizeof(float)) {
|
||||
return false;
|
||||
}
|
||||
if (src0->nb[1] != src0->ne[0] * sizeof(float) || src1->nb[1] != src1->ne[0] * sizeof(float)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dim
|
|||
cache[i0 + 1] = sinf(theta_final) * mscale_final;
|
||||
}
|
||||
|
||||
static void rope_cache_init(const float theta_base,
|
||||
static __attribute__((noinline)) void rope_cache_init(const float theta_base,
|
||||
const float freq_scale,
|
||||
const float * freq_factors,
|
||||
float * corr_dims,
|
||||
|
|
@ -129,7 +129,7 @@ static void rope_cache_init(const float theta_base,
|
|||
|
||||
// pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra).
|
||||
// sections[4]: number of head dims assigned to each position component.
|
||||
static void mrope_cache_init(const float pos_t,
|
||||
static __attribute__((noinline)) void mrope_cache_init(const float pos_t,
|
||||
const float pos_h,
|
||||
const float pos_w,
|
||||
const float pos_e,
|
||||
|
|
|
|||
|
|
@ -20,55 +20,56 @@
|
|||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define htp_ssm_conv_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict src1 = octx->src[1]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
||||
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
||||
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
#define htp_ssm_conv_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict src1 = octx->src[1]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
||||
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
||||
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
// State shared by the SSM_CONV worker threads; passed to each worker as its `data` argument.
struct htp_ssm_conv_context {
    struct htp_ops_context * octx;              // op context (tensors, scratchpads, DMA queues)
    uint32_t                 nrows_per_thread;  // d_inner rows assigned to each thread
    uint32_t                 d_inner_tile;      // d_inner rows staged into VTCM per tile (HVX path)
    uint64_t                 t_start;           // NOTE(review): presumably a start timestamp; not used in the visible code
};
|
||||
|
||||
#define htp_ssm_conv_preamble \
|
||||
#define htp_ssm_conv_preamble \
|
||||
struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \
|
||||
struct htp_ops_context * octx = scctx->octx; \
|
||||
htp_ssm_conv_tensors_preamble; \
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
struct htp_ops_context * octx = scctx->octx; \
|
||||
htp_ssm_conv_tensors_preamble; \
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
// Scalar FP32 SSM_CONV implementation
|
||||
static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
|
|
@ -128,118 +129,211 @@ static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
|||
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension
|
||||
|
||||
// In-register 32x32 fp32 transpose using std 5-stage HVX vshuff butterfly.
|
||||
static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) {
|
||||
HVX_Vector tmp[32];
|
||||
|
||||
// Stage 0 (R = -4): pair (2i, 2i+1) for i = 0..15. m -> tmp.
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
HVX_VectorPair p = Q6_W_vshuff_VVR(m[2*i + 1], m[2*i], -4);
|
||||
tmp[2*i + 0] = Q6_V_lo_W(p);
|
||||
tmp[2*i + 1] = Q6_V_hi_W(p);
|
||||
}
|
||||
|
||||
// Stage 1 (R = -8): per block of 4, pair (b+0, b+2) and (b+1, b+3). tmp -> m.
|
||||
for (int b = 0; b < 32; b += 4) {
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(tmp[b + 2], tmp[b + 0], -8);
|
||||
HVX_VectorPair p1 = Q6_W_vshuff_VVR(tmp[b + 3], tmp[b + 1], -8);
|
||||
m[b + 0] = Q6_V_lo_W(p0); m[b + 1] = Q6_V_hi_W(p0);
|
||||
m[b + 2] = Q6_V_lo_W(p1); m[b + 3] = Q6_V_hi_W(p1);
|
||||
}
|
||||
|
||||
// Stage 2 (R = -16): per block of 8, pair (b+i, b+i+4) for i = 0..3. m -> tmp.
|
||||
for (int b = 0; b < 32; b += 8) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
HVX_VectorPair p = Q6_W_vshuff_VVR(m[b + i + 4], m[b + i], -16);
|
||||
tmp[b + 2*i + 0] = Q6_V_lo_W(p);
|
||||
tmp[b + 2*i + 1] = Q6_V_hi_W(p);
|
||||
}
|
||||
}
|
||||
|
||||
// Stage 3 (R = -32): per block of 16, pair (b+i, b+i+8) for i = 0..7. tmp -> m.
|
||||
for (int b = 0; b < 32; b += 16) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
HVX_VectorPair p = Q6_W_vshuff_VVR(tmp[b + i + 8], tmp[b + i], -32);
|
||||
m[b + 2*i + 0] = Q6_V_lo_W(p);
|
||||
m[b + 2*i + 1] = Q6_V_hi_W(p);
|
||||
}
|
||||
}
|
||||
|
||||
// Stage 4 (R = -64): pair (i, i+16) for i = 0..15. m -> tmp -> m.
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
HVX_VectorPair p = Q6_W_vshuff_VVR(m[i + 16], m[i], -64);
|
||||
tmp[2 * i + 0] = Q6_V_lo_W(p);
|
||||
tmp[2 * i + 1] = Q6_V_hi_W(p);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
m[i] = tmp[i];
|
||||
}
|
||||
}
|
||||
|
||||
// HVX FP32 SSM_CONV implementation - channel-vectorized HVX kernel with src0/src1
|
||||
// transposed into VTCM.
|
||||
//
|
||||
// VTCM layouts (per thread):
|
||||
// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile.
|
||||
//
|
||||
// d_inner_tile is chosen so that per-thread VTCM stays under the budget.
|
||||
// Each thread iterates ceil(d_inner_per_thread / d_inner_tile) tiles serially.
|
||||
#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread
|
||||
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM)
//
// Copies d_inner_per_thread source rows, starting at row i1_off, into src1_T so that
// tap j of local channel i lands at src1_T[j * d_inner_per_thread + i].
//
// src1_data          : base of the source weights
// src1_stride_inner  : distance between consecutive source rows, in floats
// i1_off             : first source row to copy
// d_inner_per_thread : number of rows (channels) to copy
// d_conv             : number of taps per row
// src1_T             : destination, must hold d_conv * d_inner_per_thread floats
static inline void transpose_src1(const float * src1_data,
                                  uint32_t      src1_stride_inner,
                                  uint32_t      i1_off,
                                  uint32_t      d_inner_per_thread,
                                  uint32_t      d_conv,
                                  float *       src1_T) {
    for (uint32_t ch = 0; ch < d_inner_per_thread; ++ch) {
        const float * taps = &src1_data[(i1_off + ch) * src1_stride_inner];

        // Walk the destination column for this channel: one element per tap,
        // d_inner_per_thread floats apart.
        uint32_t out_idx = ch;
        for (uint32_t tap = 0; tap < d_conv; ++tap) {
            src1_T[out_idx] = taps[tap];
            out_idx += d_inner_per_thread;
        }
    }
}
|
||||
|
||||
// HVX 32x32 src0 transpose: src0 {ncs, d_inner} (DDR) -> src0_T {d_inner_tile, ncs} (VTCM)
|
||||
static inline void transpose_src0_block(const float * src0_block,
|
||||
uint32_t ncs,
|
||||
uint32_t cb_n,
|
||||
uint32_t d_inner_tile,
|
||||
float * src0_T_block_dst,
|
||||
uint32_t cb /* dst column offset */) {
|
||||
const uint32_t T_TILE = VLEN_FP32;
|
||||
|
||||
HVX_Vector __attribute__((aligned(VLEN))) sub[32];
|
||||
|
||||
for (uint32_t t0 = 0; t0 < ncs; t0 += T_TILE) {
|
||||
const uint32_t t_n = MIN(T_TILE, ncs - t0);
|
||||
|
||||
// Load 32 rows (channels) of T_TILE samples; pad missing channels with zeros.
|
||||
for (uint32_t r = 0; r < cb_n; ++r) {
|
||||
const float * src_row = src0_block + r * ncs + t0;
|
||||
if (t_n == T_TILE) {
|
||||
sub[r] = *(const HVX_UVector *) src_row;
|
||||
} else {
|
||||
HVX_Vector v = hvx_vec_splat_f32(0.0f);
|
||||
hvx_vec_store_u(&v, t_n * sizeof(float), hvx_vec_splat_f32(0.0f));
|
||||
|
||||
float __attribute__((aligned(VLEN))) tmp[VLEN_FP32] = { 0 };
|
||||
for (uint32_t k = 0; k < t_n; ++k) tmp[k] = src_row[k];
|
||||
v = *(const HVX_Vector *) tmp;
|
||||
sub[r] = v;
|
||||
}
|
||||
}
|
||||
for (uint32_t r = cb_n; r < T_TILE; ++r) {
|
||||
sub[r] = hvx_vec_splat_f32(0.0f);
|
||||
}
|
||||
|
||||
hvx_transpose_32x32_f32(sub);
|
||||
|
||||
// Store transposed sub-tile to src0_T at offsets (t0 + j) * d_inner_tile + cb.
|
||||
// Only write the valid t_n rows of the transposed result.
|
||||
for (uint32_t r = 0; r < t_n; ++r) {
|
||||
float * dst = src0_T_block_dst + (t0 + r) * d_inner_tile + cb;
|
||||
if (cb_n == T_TILE) {
|
||||
*(HVX_UVector *) dst = sub[r];
|
||||
} else {
|
||||
hvx_vec_store_u(dst, cb_n * sizeof(float), sub[r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) {
|
||||
htp_ssm_conv_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const int nc = src1->ne[0]; // d_conv
|
||||
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
|
||||
|
||||
const uint32_t d_conv = src1->ne[0];
|
||||
const uint32_t d_inner = src0->ne[1];
|
||||
const uint32_t n_t = dst->ne[1];
|
||||
const uint32_t n_s = dst->ne[2];
|
||||
const uint32_t ncs = src0->ne[0];
|
||||
|
||||
const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float);
|
||||
const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float);
|
||||
const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float);
|
||||
const uint32_t dst_stride_token = dst->nb[1] / sizeof(float);
|
||||
const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float);
|
||||
|
||||
const uint32_t dr = scctx->nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = MIN(ir0 + dr, d_inner);
|
||||
|
||||
if (ir0 >= ir1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t d_inner_per_thread = ir1 - ir0;
|
||||
const uint32_t d_inner_tile = scctx->d_inner_tile;
|
||||
|
||||
const float * src0_data = (const float *) src0->data;
|
||||
const float * src1_data = (const float *) src1->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
// Calculate row range for this thread
|
||||
const int dr = scctx->nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = MIN(ir0 + dr, d_inner);
|
||||
const uint32_t ir = ir1 - ir0;
|
||||
// Per-thread VTCM regions.
|
||||
float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread);
|
||||
float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread);
|
||||
|
||||
if (ir0 >= ir1) {
|
||||
return; // No work for this thread
|
||||
}
|
||||
// Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T);
|
||||
|
||||
// src0 and src1 gather offsets
|
||||
uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 };
|
||||
uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 };
|
||||
|
||||
for (uint32_t i = 0; i < VLEN_FP32; ++i) {
|
||||
src0_offsets[i] = i * (ncs) * sizeof(float);
|
||||
src1_offsets[i] = i * (d_conv) * sizeof(float);
|
||||
}
|
||||
|
||||
const uint32_t src0_gather_len = VLEN * ncs;
|
||||
const uint32_t src1_gather_len = VLEN * d_conv;
|
||||
|
||||
// gather scratchpads
|
||||
HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0);
|
||||
HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN);
|
||||
|
||||
float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]);
|
||||
float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]);
|
||||
|
||||
uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
|
||||
uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
|
||||
|
||||
// copy src1 workload to VTCM
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);
|
||||
|
||||
// FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir);
|
||||
const uint32_t C_TILE = VLEN_FP32;
|
||||
|
||||
for (uint32_t i3 = 0; i3 < n_s; ++i3) {
|
||||
float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));
|
||||
for (uint32_t tile_off = 0; tile_off < d_inner_per_thread; tile_off += d_inner_tile) {
|
||||
const uint32_t tile_n = MIN(d_inner_tile, d_inner_per_thread - tile_off);
|
||||
|
||||
// copy src0 workload to VTCM
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir);
|
||||
// Place src0 chunk into VTCM in {d_inner_tile, ncs} layout.
|
||||
const float * src0_block = src0_data + i3 * src0_stride_seq + (ir0 + tile_off) * src0_stride_inner;
|
||||
|
||||
// FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir);
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
for (uint32_t i2 = 0; i2 < n_t; ++i2) {
|
||||
float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));
|
||||
|
||||
const uint32_t nvec = ir / VLEN_FP32;
|
||||
const uint32_t nloe = ir % VLEN_FP32;
|
||||
uint32_t i1 = 0;
|
||||
|
||||
for (uint32_t vi1 = 0; vi1 < nvec; vi1++) {
|
||||
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
|
||||
uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
|
||||
uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float);
|
||||
Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
|
||||
Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
|
||||
|
||||
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
|
||||
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
|
||||
}
|
||||
|
||||
*(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec);
|
||||
i1 += VLEN_FP32;
|
||||
for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) {
|
||||
const uint32_t cb_n = MIN(C_TILE, tile_n - cb);
|
||||
transpose_src0_block(src0_block + cb * src0_stride_inner, ncs, cb_n, d_inner_tile, src0_T, cb);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
|
||||
for (uint32_t t = 0; t < n_t; ++t) {
|
||||
for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) {
|
||||
const uint32_t cb_n = MIN(C_TILE, tile_n - cb);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
|
||||
uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
|
||||
uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float);
|
||||
Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
|
||||
Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
|
||||
HVX_Vector acc = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb);
|
||||
acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w));
|
||||
}
|
||||
HVX_Vector res = Q6_Vsf_equals_Vqf32(acc);
|
||||
|
||||
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
|
||||
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
|
||||
float * dst_ptr = dst_data + i3 * dst_stride_seq + t * dst_stride_token + (ir0 + tile_off + cb);
|
||||
if (cb_n == C_TILE) {
|
||||
*(HVX_UVector *) dst_ptr = res;
|
||||
} else {
|
||||
hvx_vec_store_u(dst_ptr, cb_n * sizeof(float), res);
|
||||
}
|
||||
}
|
||||
|
||||
hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
|
||||
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
|
||||
FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) tile=%u * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
|
||||
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, d_inner_tile,
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
|
||||
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
|
@ -264,46 +358,44 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
|||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t use_hvx = 0;
|
||||
if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) {
|
||||
int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&
|
||||
hex_is_aligned((void *) src1->data, VLEN) &&
|
||||
hex_is_aligned((void *) dst->data, VLEN);
|
||||
|
||||
if (is_aligned) {
|
||||
use_hvx = 1;
|
||||
}
|
||||
if (d_inner >= VLEN_FP32 && n_t >= VLEN_FP32) {
|
||||
use_hvx = 1;
|
||||
}
|
||||
|
||||
if (use_hvx) {
|
||||
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread
|
||||
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even
|
||||
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads;
|
||||
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);
|
||||
|
||||
octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256);
|
||||
const uint32_t d_inner_per_thread = scctx.nrows_per_thread;
|
||||
const uint32_t ncs = src0->ne[0];
|
||||
|
||||
const uint32_t src1_T_size = hex_round_up(d_conv * d_inner_per_thread * sizeof(float), 256);
|
||||
const uint32_t src0_T_max = HTP_SSM_CONV_VTCM_BUDGET > src1_T_size ? HTP_SSM_CONV_VTCM_BUDGET - src1_T_size : 0;
|
||||
|
||||
uint32_t d_inner_tile = (src0_T_max / sizeof(float)) / ncs;
|
||||
d_inner_tile -= (d_inner_tile % VLEN_FP32);
|
||||
if (d_inner_tile == 0) {
|
||||
FARF(HIGH, "ssm_conv-f32: inner tile rounds to 0 (ncs=%u), falling back to scalar\n", ncs);
|
||||
use_hvx = 0;
|
||||
} else {
|
||||
scctx.d_inner_tile = d_inner_tile;
|
||||
|
||||
octx->src0_spad.size_per_thread = hex_round_up(d_inner_tile * ncs * sizeof(float), 256);
|
||||
octx->src1_spad.size_per_thread = src1_T_size;
|
||||
octx->dst_spad.size_per_thread = 0;
|
||||
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;
|
||||
octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;
|
||||
octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads;
|
||||
octx->dst_spad.size = 0;
|
||||
|
||||
// Compute gather scratchpad size for src0 and src1
|
||||
const size_t gather_spad_size = n_threads * VLEN * 2;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.src = NULL;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
|
||||
gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,
|
||||
octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size,
|
||||
octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data);
|
||||
|
||||
const size_t total_spad_size =
|
||||
gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (total_spad_size > octx->ctx->vtcm_size) {
|
||||
FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size,
|
||||
octx->ctx->vtcm_size);
|
||||
const size_t total_spad = octx->src0_spad.size + octx->src1_spad.size;
|
||||
if (total_spad > octx->ctx->vtcm_size) {
|
||||
FARF(HIGH, "ssm_conv-f32: scratchpad %zu exceeds VTCM %zu, falling back to scalar\n",
|
||||
total_spad, octx->ctx->vtcm_size);
|
||||
use_hvx = 0;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9337,6 +9337,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
|
||||
// Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {937, 8192, 1, 1}, {4, 8192, 1, 1})); // prefill
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate
|
||||
test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // prefill
|
||||
test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // generate
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue