mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-06-01 06:00:36 +00:00
hexagon: add support for CONCAT op (#23648)
Some checks failed
Check Pre-Tokenizer Hashes / pre-tokenizer-hashes (push) Has been cancelled
Python check requirements.txt / check-requirements (push) Has been cancelled
Python Type-Check / python type-check (push) Has been cancelled
Update Operations Documentation / update-ops-docs (push) Has been cancelled
Some checks failed
Check Pre-Tokenizer Hashes / pre-tokenizer-hashes (push) Has been cancelled
Python check requirements.txt / check-requirements (push) Has been cancelled
Python Type-Check / python type-check (push) Has been cancelled
Update Operations Documentation / update-ops-docs (push) Has been cancelled
* hexagon: add support for CONCAT with optimized concat_2d_transposed qwen3.5 models are quite heavy on the CONCAT with large and transposed src1. * hex-concat: use fastdiv in generic version * hex-concat: make checks for transposed a bit more readable * hex-concat: reorder dma ops for better pipelining * hex-cont/cpy: optimize CPY and CONT ops The primary change is to avoid scalar divs in the inner loops. We were calling hvx_copy_uu(... type_size) where type_size is not a constexpr. This causes runtime divs by that value which is normally just 4 or 2 (f32/f16). * hex-get-rows: optimize GET_ROWS for large rows We now use DMA for larger rows and also split them into chunks to improve perf for Qwen3.5 and other models that do lots of GET_ROWS with huge (2MB+ rows). Also bump the DMA queue depth now that we can take advantage of it. * hex-concat: unroll the inner loops of concat_2d * hex-concat: more updates to concat_2d to improve perf a bit further * hex-cpy: fixed n_rows per thread checks in the copy ops * hmx-fa: fix alignment issues while computing dma sizes * hex-set-rows: add early returns for idle threads * hvx-rope: minor optimization to replace loops with fastdiv logic * hex-rope: replace scalar tail processing with HVX * hex-rope: optimize rope cache init with HVX Add hvx-utils sin/cos helpers that use an approx method (similar to rsqrt, inverse, etc) Use the helpers to optimize ROPE.
This commit is contained in:
parent
678d43d720
commit
ef66bfab68
14 changed files with 867 additions and 230 deletions
|
|
@ -2874,6 +2874,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
|
|||
case GGML_OP_NORM: return HTP_OP_NORM;
|
||||
case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
|
||||
case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
|
||||
case GGML_OP_CONCAT: return HTP_OP_CONCAT;
|
||||
case GGML_OP_SCALE: return HTP_OP_SCALE;
|
||||
case GGML_OP_SQR: return HTP_OP_SQR;
|
||||
case GGML_OP_SQRT: return HTP_OP_SQRT;
|
||||
|
|
@ -3286,6 +3287,25 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_concat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
int dim = ((const int32_t *) op->op_params)[0];
|
||||
if (dim < 0 || dim >= GGML_MAX_DIMS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
||||
const struct ggml_tensor * src = op->src[i];
|
||||
if (!src) {
|
||||
continue;
|
||||
}
|
||||
if (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_I32 && src->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
|
|
@ -3434,6 +3454,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
supp = ggml_hexagon_supported_cumsum(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_CONCAT:
|
||||
supp = ggml_hexagon_supported_concat(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_FILL:
|
||||
supp = ggml_hexagon_supported_fill(sess, op);
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ add_library(${HTP_LIB} SHARED
|
|||
ssm-conv.c
|
||||
cumsum-ops.c
|
||||
fill-ops.c
|
||||
concat-ops.c
|
||||
diag-ops.c
|
||||
solve-tri-ops.c
|
||||
gated-delta-net-ops.c
|
||||
|
|
|
|||
275
ggml/src/ggml-hexagon/htp/concat-ops.c
Normal file
275
ggml/src/ggml-hexagon/htp/concat-ops.c
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hexagon_types.h"
|
||||
#include "hexagon_protos.h"
|
||||
#include "hvx_hexagon_protos.h"
|
||||
#include "hex-dma.h"
|
||||
#include "vtcm-utils.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-fastdiv.h"
|
||||
#include <string.h>
|
||||
|
||||
// Per-invocation state handed to the CONCAT worker functions through the
// worker pool's opaque data pointer.
struct htp_concat_context {
|
||||
// Operation context the workers read src/dst tensors and scratchpads from.
struct htp_ops_context * octx;
|
||||
// Concat dimension, copied from op_params[0] by op_concat.
uint32_t dim;
|
||||
// Rows of dst assigned to each thread; set by op_concat only when it selects
// one of the transposed 2D workers.
uint32_t nrows_per_thread;
|
||||
// fastdiv values for dst->ne[0], dst->ne[1] and dst->ne[2]; concat_generic
// uses them to split a flat element index into coordinates.
struct fastdiv_values div_ne0;
|
||||
struct fastdiv_values div_ne1;
|
||||
struct fastdiv_values div_ne2;
|
||||
};
|
||||
|
||||
static void concat_2d_f32_transposed(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_concat_context * cctx = (struct htp_concat_context *) data;
|
||||
struct htp_ops_context * octx = cctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t src0_ne0 = src0->ne[0];
|
||||
const uint32_t src1_ne0 = src1->ne[0];
|
||||
const uint32_t ne1 = dst->ne[1];
|
||||
|
||||
const uint32_t start_i = ith * cctx->nrows_per_thread;
|
||||
const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1;
|
||||
if (start_i >= end_i) return;
|
||||
|
||||
dma_queue * q = octx->ctx->dma[ith];
|
||||
|
||||
uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
|
||||
uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
|
||||
|
||||
const uint32_t block_i = 32;
|
||||
const uint32_t spad1_stride = block_i * sizeof(float);
|
||||
|
||||
int32_t offsets[32] __attribute__((aligned(128)));
|
||||
for(int k=0; k<32; k++) {
|
||||
offsets[k] = k * spad1_stride;
|
||||
}
|
||||
HVX_Vector vv = *(HVX_Vector*)offsets;
|
||||
const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 32);
|
||||
const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(float), VLEN);
|
||||
uint32_t mu = src1_ne0_padded * spad1_stride;
|
||||
|
||||
for (uint32_t i = start_i; i < end_i; i += block_i) {
|
||||
uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i;
|
||||
|
||||
uint32_t src1_width_bytes = current_block_i * sizeof(float);
|
||||
uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1];
|
||||
dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0);
|
||||
|
||||
uint32_t src0_row_bytes = src0_ne0 * sizeof(float);
|
||||
uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1];
|
||||
dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i);
|
||||
|
||||
dma_queue_pop(q); // src1
|
||||
|
||||
HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride);
|
||||
|
||||
for (uint32_t j = 0; j < src1_ne0_padded; j += 32) {
|
||||
#pragma unroll(4)
|
||||
for (uint32_t ii = 0; ii < current_block_i; ii++) {
|
||||
size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(float));
|
||||
Q6_vgather_ARMVw(&vtcm_tmp[ii], rt, mu, vv);
|
||||
uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(float);
|
||||
hvx_vmemu(dst_ptr) = vtcm_tmp[ii];
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_pop(q); // src0
|
||||
|
||||
uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1];
|
||||
dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(float), current_block_i);
|
||||
|
||||
dma_queue_pop(q);
|
||||
}
|
||||
}
|
||||
|
||||
static void concat_2d_f16_transposed(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_concat_context * cctx = (struct htp_concat_context *) data;
|
||||
struct htp_ops_context * octx = cctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t src0_ne0 = src0->ne[0];
|
||||
const uint32_t src1_ne0 = src1->ne[0];
|
||||
const uint32_t ne1 = dst->ne[1];
|
||||
|
||||
const uint32_t start_i = ith * cctx->nrows_per_thread;
|
||||
const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1;
|
||||
if (start_i >= end_i) return;
|
||||
|
||||
dma_queue * q = octx->ctx->dma[ith];
|
||||
|
||||
uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
|
||||
uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
|
||||
|
||||
const uint32_t block_i = 64;
|
||||
const uint32_t spad1_stride = block_i * sizeof(__fp16);
|
||||
|
||||
int16_t offsets[64] __attribute__((aligned(128)));
|
||||
for(int k=0; k<64; k++) {
|
||||
offsets[k] = k * spad1_stride;
|
||||
}
|
||||
HVX_Vector vv = *(HVX_Vector*)offsets;
|
||||
const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 64);
|
||||
const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(__fp16), VLEN);
|
||||
uint32_t mu = src1_ne0_padded * spad1_stride;
|
||||
|
||||
for (uint32_t i = start_i; i < end_i; i += block_i) {
|
||||
uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i;
|
||||
|
||||
uint32_t src1_width_bytes = current_block_i * sizeof(__fp16);
|
||||
uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1];
|
||||
dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0);
|
||||
|
||||
uint32_t src0_row_bytes = src0_ne0 * sizeof(__fp16);
|
||||
uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1];
|
||||
dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i);
|
||||
|
||||
dma_queue_pop(q); // src1
|
||||
|
||||
HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride);
|
||||
|
||||
for (uint32_t j = 0; j < src1_ne0_padded; j += 64) {
|
||||
#pragma unroll(4)
|
||||
for (uint32_t ii = 0; ii < current_block_i; ii++) {
|
||||
size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(__fp16));
|
||||
Q6_vgather_ARMVh(&vtcm_tmp[ii], rt, mu, vv);
|
||||
uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(__fp16);
|
||||
hvx_vmemu(dst_ptr) = vtcm_tmp[ii];
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_pop(q); // src0
|
||||
|
||||
uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1];
|
||||
dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(__fp16), current_block_i);
|
||||
|
||||
dma_queue_pop(q);
|
||||
}
|
||||
}
|
||||
|
||||
static void concat_generic(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_concat_context * cctx = (struct htp_concat_context *) data;
|
||||
struct htp_ops_context * octx = cctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const int dim = cctx->dim;
|
||||
const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2;
|
||||
|
||||
const uint32_t ne[4] = {dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]};
|
||||
const uint32_t total_elements = ne[0] * ne[1] * ne[2] * ne[3];
|
||||
const uint32_t chunk_size = (total_elements + nth - 1) / nth;
|
||||
|
||||
const uint32_t start_idx = MIN(ith * chunk_size, total_elements);
|
||||
const uint32_t end_idx = MIN(start_idx + chunk_size, total_elements);
|
||||
|
||||
// Naive scalar element-wise copy
|
||||
for (uint32_t idx = start_idx; idx < end_idx; idx++) {
|
||||
uint32_t idx_div_ne0 = fastdiv(idx, &cctx->div_ne0);
|
||||
uint32_t i0 = idx - idx_div_ne0 * ne[0];
|
||||
|
||||
uint32_t idx_div_ne01 = fastdiv(idx_div_ne0, &cctx->div_ne1);
|
||||
uint32_t i1 = idx_div_ne0 - idx_div_ne01 * ne[1];
|
||||
|
||||
uint32_t idx_div_ne012 = fastdiv(idx_div_ne01, &cctx->div_ne2);
|
||||
uint32_t i2 = idx_div_ne01 - idx_div_ne012 * ne[2];
|
||||
uint32_t i3 = idx_div_ne012;
|
||||
|
||||
uint8_t * dst_ptr = (uint8_t *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2] + i1 * dst->nb[1] + i0 * dst->nb[0];
|
||||
|
||||
uint32_t idx_dim = 0;
|
||||
if (dim == 0) idx_dim = i0;
|
||||
else if (dim == 1) idx_dim = i1;
|
||||
else if (dim == 2) idx_dim = i2;
|
||||
else if (dim == 3) idx_dim = i3;
|
||||
|
||||
const struct htp_tensor * src = (idx_dim < src0->ne[dim]) ? src0 : src1;
|
||||
|
||||
uint32_t s0 = i0;
|
||||
uint32_t s1 = i1;
|
||||
uint32_t s2 = i2;
|
||||
uint32_t s3 = i3;
|
||||
|
||||
if (dim == 0 && src == src1) s0 -= src0->ne[0];
|
||||
if (dim == 1 && src == src1) s1 -= src0->ne[1];
|
||||
if (dim == 2 && src == src1) s2 -= src0->ne[2];
|
||||
if (dim == 3 && src == src1) s3 -= src0->ne[3];
|
||||
|
||||
uint8_t * src_ptr = (uint8_t *)src->data + s3 * src->nb[3] + s2 * src->nb[2] + s1 * src->nb[1] + s0 * src->nb[0];
|
||||
|
||||
if (type_size == 4) {
|
||||
*(float*)dst_ptr = *(float*)src_ptr;
|
||||
} else {
|
||||
*(__fp16*)dst_ptr = *(__fp16*)src_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int op_concat(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
int dim = octx->op_params[0];
|
||||
|
||||
bool is_2d = dst->ne[2] == 1 && dst->ne[3] == 1;
|
||||
|
||||
const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2;
|
||||
bool is_src1_transposed = (src1->nb[0] > src1->nb[1]);
|
||||
bool is_src0_transposed = (src0->nb[0] > src0->nb[1]);
|
||||
|
||||
uint32_t n_threads = octx->n_threads;
|
||||
struct htp_concat_context cctx;
|
||||
cctx.octx = octx;
|
||||
cctx.dim = dim;
|
||||
cctx.div_ne0 = init_fastdiv_values(dst->ne[0]);
|
||||
cctx.div_ne1 = init_fastdiv_values(dst->ne[1]);
|
||||
cctx.div_ne2 = init_fastdiv_values(dst->ne[2]);
|
||||
|
||||
void (*worker_func)(unsigned int, unsigned int, void *) = concat_generic;
|
||||
|
||||
if (dim == 0 && is_2d && is_src1_transposed && !is_src0_transposed) {
|
||||
n_threads = MIN(dst->ne[1], n_threads);
|
||||
if (n_threads < 1) {
|
||||
n_threads = 1;
|
||||
}
|
||||
uint32_t block_i = (type_size == 4) ? 32 : 64;
|
||||
|
||||
cctx.nrows_per_thread = hmx_ceil_div(dst->ne[1], n_threads);
|
||||
|
||||
// Allocate VTCM
|
||||
uint32_t spad1_stride = block_i * type_size;
|
||||
|
||||
uint32_t src1_ne0_padded = hex_round_up(src1->ne[0], block_i);
|
||||
uint32_t spad0_row_bytes = hex_round_up((src0->ne[0] + src1_ne0_padded) * type_size, VLEN);
|
||||
|
||||
octx->src0_spad.size_per_thread = block_i * spad0_row_bytes;
|
||||
octx->src1_spad.size_per_thread = src1_ne0_padded * spad1_stride + block_i * VLEN;
|
||||
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
|
||||
|
||||
if (octx->src0_spad.size + octx->src1_spad.size > octx->ctx->vtcm_size) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
if (type_size == 4) {
|
||||
worker_func = concat_2d_f32_transposed;
|
||||
} else {
|
||||
worker_func = concat_2d_f16_transposed;
|
||||
}
|
||||
}
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, worker_func, &cctx, n_threads);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
@ -28,158 +28,170 @@ struct htp_copy_context {
|
|||
uint32_t dst_blocks_per_row;
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
|
||||
void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith);
|
||||
};
|
||||
|
||||
#define cpy_preamble \
|
||||
const struct htp_tensor *src0 = octx->src[0]; \
|
||||
const struct htp_tensor *dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
\
|
||||
const uint32_t nr = ne01;
|
||||
|
||||
static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
cpy_preamble;
|
||||
|
||||
// parallelize by src0 rows
|
||||
const uint32_t dr = ct->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
|
||||
|
||||
// copy by rows
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) {
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i01 = ir0; i01 < ir1; i01++) {
|
||||
uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2);
|
||||
hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
#define DEFINE_CPY_SAMESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \
|
||||
static void cpy_thread_##NAME##_sameshape(unsigned int nth, unsigned int ith, void * data) { \
|
||||
struct htp_copy_context * ct = (struct htp_copy_context *) data; \
|
||||
struct htp_ops_context * octx = ct->octx; \
|
||||
cpy_preamble; \
|
||||
const uint32_t dr = ct->src0_nrows_per_thread; \
|
||||
const uint32_t ir0 = dr * ith; \
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \
|
||||
if (ir0 >= nr) return; \
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) { \
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) { \
|
||||
_Pragma("unroll(4)") \
|
||||
for (uint32_t i01 = ir0; i01 < ir1; i01++) { \
|
||||
uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; \
|
||||
uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; \
|
||||
hex_l2fetch(src0_ptr, ne00 * ELEM_SIZE, nb01, 2); \
|
||||
hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) {
|
||||
cpy_preamble;
|
||||
DEFINE_CPY_SAMESHAPE(f32, float, 4)
|
||||
DEFINE_CPY_SAMESHAPE(f16, __fp16, 2)
|
||||
|
||||
// parallelize by src0 rows
|
||||
const uint32_t dr = ct->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
|
||||
|
||||
// Fast path: when both src0 and dst are contiguous in memory
|
||||
// Replace the element-by-element loop with a single bulk HVX copy per (i03, i02) slice.
|
||||
const bool src0_contig = (nb00 == ct->src0_type_size) &&
|
||||
(nb01 == ne00 * nb00) &&
|
||||
(nb02 == ne01 * nb01) &&
|
||||
(nb03 == ne02 * nb02);
|
||||
const bool dst_contig = (nb0 == ct->dst_type_size) &&
|
||||
(nb1 == ne0 * nb0) &&
|
||||
(nb2 == ne1 * nb1) &&
|
||||
(nb3 == ne2 * nb2);
|
||||
|
||||
if (src0_contig && dst_contig) {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01;
|
||||
uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00;
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ct->src0_type_size;
|
||||
hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ct->src0_type_size);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// dst counters
|
||||
int64_t k10 = 0;
|
||||
int64_t i11 = 0;
|
||||
int64_t i12 = 0;
|
||||
int64_t i13 = 0;
|
||||
|
||||
// number of blocks in a row
|
||||
const int64_t nk00 = ct->src0_blocks_per_row;
|
||||
const int64_t nk0 = ct->dst_blocks_per_row;
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
k10 += nk00 * ir0;
|
||||
while (k10 >= nk0) {
|
||||
k10 -= nk0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int64_t k00 = 0; k00 < nk00; k00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
memcpy(dst_ptr, src0_ptr, ct->dst_type_size);
|
||||
|
||||
if (++k10 == nk0) {
|
||||
k10 = 0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
k10 += nk00 * (ne01 - ir1);
|
||||
while (k10 >= nk0) {
|
||||
k10 -= nk0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#define DEFINE_CPY_RESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \
|
||||
static void cpy_thread_##NAME##_reshape(unsigned int nth, unsigned int ith, void * data) { \
|
||||
struct htp_copy_context * ct = (struct htp_copy_context *) data; \
|
||||
struct htp_ops_context * octx = ct->octx; \
|
||||
cpy_preamble; \
|
||||
const uint32_t dr = ct->src0_nrows_per_thread; \
|
||||
const uint32_t ir0 = dr * ith; \
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \
|
||||
if (ir0 >= nr) return; \
|
||||
const bool src0_contig = (nb00 == ELEM_SIZE) && \
|
||||
(nb01 == ne00 * nb00) && \
|
||||
(nb02 == ne01 * nb01) && \
|
||||
(nb03 == ne02 * nb02); \
|
||||
const bool dst_contig = (nb0 == ELEM_SIZE) && \
|
||||
(nb1 == ne0 * nb0) && \
|
||||
(nb2 == ne1 * nb1) && \
|
||||
(nb3 == ne2 * nb2); \
|
||||
if (src0_contig && dst_contig) { \
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) { \
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) { \
|
||||
uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; \
|
||||
uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; \
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ELEM_SIZE; \
|
||||
hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ELEM_SIZE); \
|
||||
} \
|
||||
} \
|
||||
return; \
|
||||
} \
|
||||
const bool reshape_flat_fast = (ne03 == 1 && ne2 == 1 && ne3 == 1) && \
|
||||
(ne0 == ne00 * ne01) && (ne1 == ne02) && \
|
||||
(nb00 == ELEM_SIZE) && (nb0 == ELEM_SIZE); \
|
||||
if (reshape_flat_fast) { \
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) { \
|
||||
for (uint32_t i01 = ir0; i01 < ir1; i01++) { \
|
||||
uint8_t * src0_ptr = (uint8_t *) src0->data + i01 * nb01 + i02 * nb02; \
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + i01 * ne00 * ELEM_SIZE + i02 * nb1; \
|
||||
hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \
|
||||
} \
|
||||
} \
|
||||
return; \
|
||||
} \
|
||||
int64_t k10 = 0; \
|
||||
int64_t i11 = 0; \
|
||||
int64_t i12 = 0; \
|
||||
int64_t i13 = 0; \
|
||||
const int64_t nk00 = ct->src0_blocks_per_row; \
|
||||
const int64_t nk0 = ct->dst_blocks_per_row; \
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) { \
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) { \
|
||||
k10 += nk00 * ir0; \
|
||||
while (k10 >= nk0) { \
|
||||
k10 -= nk0; \
|
||||
if (++i11 == ne1) { \
|
||||
i11 = 0; \
|
||||
if (++i12 == ne2) { \
|
||||
i12 = 0; \
|
||||
if (++i13 == ne3) { \
|
||||
i13 = 0; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) { \
|
||||
for (int64_t k00 = 0; k00 < nk00; k00++) { \
|
||||
const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); \
|
||||
char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); \
|
||||
memcpy(dst_ptr, src0_ptr, ELEM_SIZE); \
|
||||
if (++k10 == nk0) { \
|
||||
k10 = 0; \
|
||||
if (++i11 == ne1) { \
|
||||
i11 = 0; \
|
||||
if (++i12 == ne2) { \
|
||||
i12 = 0; \
|
||||
if (++i13 == ne3) { \
|
||||
i13 = 0; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
k10 += nk00 * (ne01 - ir1); \
|
||||
while (k10 >= nk0) { \
|
||||
k10 -= nk0; \
|
||||
if (++i11 == ne1) { \
|
||||
i11 = 0; \
|
||||
if (++i12 == ne2) { \
|
||||
i12 = 0; \
|
||||
if (++i13 == ne3) { \
|
||||
i13 = 0; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
DEFINE_CPY_RESHAPE(f32, float, 4)
|
||||
DEFINE_CPY_RESHAPE(f16, __fp16, 2)
|
||||
|
||||
static void cpy_thread_f16_f32_sameshape(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_copy_context * ct = (struct htp_copy_context *) data;
|
||||
struct htp_ops_context * octx = ct->octx;
|
||||
cpy_preamble;
|
||||
|
||||
// parallelize by src0 rows
|
||||
const uint32_t dr = ct->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
|
||||
if (ir0 >= nr) return;
|
||||
|
||||
// copy by rows
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
|
|
@ -195,13 +207,16 @@ static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct ht
|
|||
}
|
||||
}
|
||||
|
||||
static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
static void cpy_thread_f32_f16_sameshape(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_copy_context * ct = (struct htp_copy_context *) data;
|
||||
struct htp_ops_context * octx = ct->octx;
|
||||
cpy_preamble;
|
||||
|
||||
// parallelize by src0 rows
|
||||
const uint32_t dr = ct->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
|
||||
if (ir0 >= nr) return;
|
||||
|
||||
// copy by rows
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
|
|
@ -217,11 +232,6 @@ static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct ht
|
|||
}
|
||||
}
|
||||
|
||||
static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
|
||||
struct htp_copy_context *ct = (struct htp_copy_context *) data;
|
||||
ct->copy(ct, ct->octx, n, i);
|
||||
}
|
||||
|
||||
int op_cpy(struct htp_ops_context * octx) {
|
||||
cpy_preamble;
|
||||
|
||||
|
|
@ -254,22 +264,32 @@ int op_cpy(struct htp_ops_context * octx) {
|
|||
|
||||
ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
|
||||
worker_callback_t copy_fun;
|
||||
|
||||
if (sametype && sameshape) {
|
||||
ct.copy = cpy_thread_sametype_sameshape;
|
||||
if (src0->type == HTP_TYPE_F32) {
|
||||
copy_fun = cpy_thread_f32_sameshape;
|
||||
} else {
|
||||
copy_fun = cpy_thread_f16_sameshape;
|
||||
}
|
||||
} else if (sameshape) {
|
||||
/**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32)
|
||||
ct.copy = cpy_thread_f16_f32_sameshape;
|
||||
copy_fun = cpy_thread_f16_f32_sameshape;
|
||||
else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16)
|
||||
ct.copy = cpy_thread_f32_f16_sameshape;
|
||||
copy_fun = cpy_thread_f32_f16_sameshape;
|
||||
else
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
} else if (sametype) {
|
||||
ct.copy = cpy_thread_sametype_reshape;
|
||||
if (src0->type == HTP_TYPE_F32) {
|
||||
copy_fun = cpy_thread_f32_reshape;
|
||||
} else {
|
||||
copy_fun = cpy_thread_f16_reshape;
|
||||
}
|
||||
} else {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, copy_fun, &ct, n_threads);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,9 +17,13 @@
|
|||
|
||||
struct get_rows_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t src1_nrows_per_thread;
|
||||
uint32_t tasks_per_thread;
|
||||
uint32_t total_tasks;
|
||||
uint32_t chunks_per_row;
|
||||
uint32_t chunk_size;
|
||||
struct fastdiv_values get_rows_div_ne10;
|
||||
struct fastdiv_values get_rows_div_ne10_ne11;
|
||||
struct fastdiv_values get_rows_div_chunks_per_row;
|
||||
};
|
||||
|
||||
#define get_rows_preamble \
|
||||
|
|
@ -52,20 +56,23 @@ struct get_rows_context {
|
|||
\
|
||||
const uint32_t nr = ne10 * ne11 * ne12;
|
||||
|
||||
static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
static void get_rows_thread_f32_f32_dma(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct get_rows_context * grctx = (struct get_rows_context *)data;
|
||||
struct htp_ops_context * octx = grctx->octx;
|
||||
get_rows_preamble;
|
||||
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
// parallelize by src1 elements (which correspond to dst rows)
|
||||
const uint32_t dr = grctx->src1_nrows_per_thread;
|
||||
const uint32_t dr = grctx->tasks_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
if (ir0 >= grctx->total_tasks) {
|
||||
return;
|
||||
}
|
||||
const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks);
|
||||
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
|
||||
const uint32_t rem = i - i12 * ne11 * ne10;
|
||||
|
|
@ -73,29 +80,77 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
|||
const uint32_t i10 = rem - i11 * ne10;
|
||||
|
||||
const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
|
||||
if (i01 >= ne01) {
|
||||
// invalid index, skip for now to avoid crash
|
||||
continue;
|
||||
}
|
||||
|
||||
const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
|
||||
while (!dma_queue_push(dma_queue, dma_make_ptr((void *)dst_ptr, (const void *)src0_ptr), nb1, nb01, ne00 * sizeof(float), 1)) {
|
||||
dma_queue_pop(dma_queue);
|
||||
}
|
||||
}
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "get-rows-f32-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
|
||||
}
|
||||
|
||||
static void get_rows_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct get_rows_context * grctx = (struct get_rows_context *)data;
|
||||
struct htp_ops_context * octx = grctx->octx;
|
||||
get_rows_preamble;
|
||||
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t dr = grctx->tasks_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
if (ir0 >= grctx->total_tasks) {
|
||||
return;
|
||||
}
|
||||
const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks);
|
||||
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
||||
const uint32_t chunks_per_row = grctx->chunks_per_row;
|
||||
const uint32_t chunk_size = grctx->chunk_size;
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t row_idx = fastdiv(i, &grctx->get_rows_div_chunks_per_row);
|
||||
const uint32_t chunk_idx = i - row_idx * chunks_per_row;
|
||||
|
||||
const uint32_t i12 = fastdiv(row_idx, &grctx->get_rows_div_ne10_ne11);
|
||||
const uint32_t rem = row_idx - i12 * ne11 * ne10;
|
||||
const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
|
||||
const uint32_t i10 = rem - i11 * ne10;
|
||||
|
||||
const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
|
||||
if (i01 >= ne01) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t offset = chunk_idx * chunk_size;
|
||||
if (offset < ne00) {
|
||||
const uint32_t copy_size = MIN(chunk_size, ne00 - offset);
|
||||
const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03 + offset * sizeof(float);
|
||||
const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3 + offset * sizeof(float);
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, copy_size);
|
||||
}
|
||||
}
|
||||
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
FARF(HIGH, "get-rows-f32-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
|
||||
}
|
||||
|
||||
int op_get_rows(struct htp_ops_context * octx) {
|
||||
get_rows_preamble;
|
||||
|
||||
const uint32_t n_threads = MIN(nr, octx->n_threads);
|
||||
|
||||
if (octx->src[0]->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
|
@ -112,13 +167,52 @@ int op_get_rows(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
const uint32_t nb00 = octx->src[0]->nb[0];
|
||||
const uint32_t nb0 = octx->dst->nb[0];
|
||||
|
||||
const bool can_use_dma = (nb00 == sizeof(float)) && (nb0 == sizeof(float));
|
||||
const bool use_dma = can_use_dma && (ne00 >= 2048);
|
||||
|
||||
struct get_rows_context grctx;
|
||||
grctx.octx = octx;
|
||||
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]);
|
||||
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]);
|
||||
|
||||
grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
if (use_dma) {
|
||||
grctx.chunks_per_row = 1;
|
||||
grctx.chunk_size = ne00;
|
||||
grctx.total_tasks = nr;
|
||||
grctx.get_rows_div_chunks_per_row = init_fastdiv_values(1);
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads);
|
||||
const uint32_t n_threads = MIN(nr, octx->n_threads);
|
||||
grctx.tasks_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_dma, &grctx, n_threads);
|
||||
} else {
|
||||
uint32_t chunks_per_row = 1;
|
||||
uint32_t chunk_size = ne00;
|
||||
uint32_t total_tasks = nr;
|
||||
|
||||
if (nr < octx->n_threads) {
|
||||
const uint32_t min_chunk_size = 1024;
|
||||
uint32_t max_chunks = ne00 / min_chunk_size;
|
||||
if (max_chunks == 0) {
|
||||
max_chunks = 1;
|
||||
}
|
||||
chunks_per_row = MIN((octx->n_threads + nr - 1) / nr, max_chunks);
|
||||
chunk_size = (ne00 + chunks_per_row - 1) / chunks_per_row;
|
||||
total_tasks = nr * chunks_per_row;
|
||||
}
|
||||
|
||||
grctx.chunks_per_row = chunks_per_row;
|
||||
grctx.chunk_size = chunk_size;
|
||||
grctx.total_tasks = total_tasks;
|
||||
grctx.get_rows_div_chunks_per_row = init_fastdiv_values(chunks_per_row);
|
||||
|
||||
const uint32_t n_threads = MIN(total_tasks, octx->n_threads);
|
||||
grctx.tasks_per_thread = (total_tasks + n_threads - 1) / n_threads;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_hvx, &grctx, n_threads);
|
||||
}
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -50,8 +50,8 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
|
|||
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
|
||||
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
|
||||
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
|
||||
const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf
|
||||
const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf
|
||||
const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf
|
||||
const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf
|
||||
const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved
|
||||
const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved
|
||||
const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc]
|
||||
|
|
@ -1278,7 +1278,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
|||
struct hmx_fa_context factx;
|
||||
memset(&factx, 0, sizeof(factx));
|
||||
factx.octx = octx;
|
||||
factx.n_threads = octx->ctx->n_threads;
|
||||
factx.n_threads = n_threads;
|
||||
factx.DK = DK;
|
||||
factx.DV = DV;
|
||||
factx.n_kv = nek1;
|
||||
|
|
@ -1328,10 +1328,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
|||
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
|
||||
|
||||
// ======== VTCM allocation (GQA-aware) ========
|
||||
const size_t size_k_row = DK * sizeof(__fp16);
|
||||
const size_t size_v_row = DV * sizeof(__fp16);
|
||||
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
|
||||
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
|
||||
|
||||
const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096);
|
||||
const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096);
|
||||
const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096);
|
||||
const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096);
|
||||
const size_t k_dma_bytes = hex_align_up(Bc * size_k_row_padded, 4096);
|
||||
const size_t v_dma_bytes = hex_align_up(Bc * size_v_row_padded, 4096);
|
||||
const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096);
|
||||
const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096);
|
||||
const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096);
|
||||
|
|
@ -1401,11 +1406,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
|||
// ======== DMA setup ========
|
||||
dma_queue * const dma = ctx->dma[0];
|
||||
|
||||
// Padded row sizes for DMA
|
||||
const size_t size_k_row = nek0 * sizeof(__fp16);
|
||||
const size_t size_v_row = nev0 * sizeof(__fp16);
|
||||
const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128);
|
||||
const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128);
|
||||
// Padded row sizes for DMA (defined in outer scope)
|
||||
|
||||
const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS;
|
||||
const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS;
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ int op_argsort(struct htp_ops_context * octx);
|
|||
int op_ssm_conv(struct htp_ops_context * octx);
|
||||
int op_cumsum(struct htp_ops_context * octx);
|
||||
int op_fill(struct htp_ops_context * octx);
|
||||
int op_concat(struct htp_ops_context * octx);
|
||||
int op_diag(struct htp_ops_context * octx);
|
||||
int op_solve_tri(struct htp_ops_context * octx);
|
||||
int op_gated_delta_net(struct htp_ops_context * octx);
|
||||
|
|
|
|||
|
|
@ -89,6 +89,7 @@ enum htp_op_code {
|
|||
HTP_OP_TRI,
|
||||
HTP_OP_PAD,
|
||||
HTP_OP_NORM,
|
||||
HTP_OP_CONCAT,
|
||||
|
||||
HTP_OP_INVALID
|
||||
};
|
||||
|
|
|
|||
90
ggml/src/ggml-hexagon/htp/hvx-sin-cos.h
Normal file
90
ggml/src/ggml-hexagon/htp/hvx-sin-cos.h
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
#ifndef HVX_SIN_COS_H
|
||||
#define HVX_SIN_COS_H
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hvx-floor.h"
|
||||
|
||||
/*
 * Lane-wise cosine of an HVX vector of f32 values.
 *
 * Range reduction: n = floor(x * (1/pi) + 0.5), y = x - n * pi, so that
 * cos(x) = (-1)^n * cos(y). cos(y) is evaluated as a degree-4 polynomial in
 * z = y^2 with Horner's scheme (the coefficients are described upstream as a
 * Chebyshev approximation).
 *
 * NOTE(review): n * pi is removed in a single f32 subtraction, so accuracy
 * presumably degrades for large |x| - confirm the intended input range.
 */
static inline HVX_Vector hvx_vec_cos_f32(HVX_Vector x) {
    // Coefficients of the polynomial in z = y^2, highest order first.
    const float coeffs[5] = {
        2.3557242013849433e-05f,
        -0.0013871428263450528f,
        0.041665895266688284f,
        -0.4999999360426369f,
        0.9999999999071725f,
    };

    const HVX_Vector v_half = hvx_vec_splat_f32(0.5f);

    // Index of the nearest multiple of pi: n = floor(x * (1/pi) + 0.5)
    const HVX_Vector scaled = hvx_vec_mul_f32_f32(x, hvx_vec_splat_f32(0.3183098861837907f));
    const HVX_Vector n      = hvx_vec_floor_f32(hvx_vec_add_f32_f32(scaled, v_half));

    // Reduced argument: y = x - n * pi
    const HVX_Vector n_pi = hvx_vec_mul_f32_f32(n, hvx_vec_splat_f32(3.141592653589793f));
    const HVX_Vector y    = hvx_vec_sub_f32_f32(x, n_pi);

    // n is odd exactly when n/2 has a fractional part, i.e. n/2 > floor(n/2);
    // odd n flips the sign of the result.
    const HVX_Vector     n_half = hvx_vec_mul_f32_f32(n, v_half);
    const HVX_VectorPred n_odd  = Q6_Q_vcmp_gt_VsfVsf(n_half, hvx_vec_floor_f32(n_half));
    const HVX_Vector     sign   = Q6_V_vmux_QVV(n_odd, hvx_vec_splat_f32(-1.0f), hvx_vec_splat_f32(1.0f));

    const HVX_Vector z = hvx_vec_mul_f32_f32(y, y);

    // Horner evaluation: acc = c[k] + z * acc, starting from the top coefficient.
    HVX_Vector acc = hvx_vec_splat_f32(coeffs[0]);
    for (int k = 1; k < 5; k++) {
        acc = hvx_vec_add_f32_f32(hvx_vec_splat_f32(coeffs[k]), hvx_vec_mul_f32_f32(z, acc));
    }

    return hvx_vec_mul_f32_f32(acc, sign);
}
|
||||
|
||||
/*
 * Lane-wise sine of an HVX vector of f32 values.
 *
 * Range reduction: n = floor(x * (1/pi) + 0.5), y = x - n * pi, so that
 * sin(x) = (-1)^n * sin(y). sin(y) is evaluated as y * P(z) where P is a
 * degree-4 polynomial in z = y^2, computed with Horner's scheme (the
 * coefficients are described upstream as a Chebyshev approximation).
 *
 * NOTE(review): n * pi is removed in a single f32 subtraction, so accuracy
 * presumably degrades for large |x| - confirm the intended input range.
 */
static inline HVX_Vector hvx_vec_sin_f32(HVX_Vector x) {
    // Coefficients of P(z), z = y^2, highest order first.
    const float coeffs[5] = {
        2.642186986152672e-06f,
        -0.00019825318964070864f,
        0.00833326283319605f,
        -0.16666666082087775f,
        0.999999999915155f,
    };

    const HVX_Vector v_half = hvx_vec_splat_f32(0.5f);

    // Index of the nearest multiple of pi: n = floor(x * (1/pi) + 0.5)
    const HVX_Vector scaled = hvx_vec_mul_f32_f32(x, hvx_vec_splat_f32(0.3183098861837907f));
    const HVX_Vector n      = hvx_vec_floor_f32(hvx_vec_add_f32_f32(scaled, v_half));

    // Reduced argument: y = x - n * pi
    const HVX_Vector n_pi = hvx_vec_mul_f32_f32(n, hvx_vec_splat_f32(3.141592653589793f));
    const HVX_Vector y    = hvx_vec_sub_f32_f32(x, n_pi);

    // n is odd exactly when n/2 has a fractional part, i.e. n/2 > floor(n/2);
    // odd n flips the sign of the result.
    const HVX_Vector     n_half = hvx_vec_mul_f32_f32(n, v_half);
    const HVX_VectorPred n_odd  = Q6_Q_vcmp_gt_VsfVsf(n_half, hvx_vec_floor_f32(n_half));
    const HVX_Vector     sign   = Q6_V_vmux_QVV(n_odd, hvx_vec_splat_f32(-1.0f), hvx_vec_splat_f32(1.0f));

    const HVX_Vector z = hvx_vec_mul_f32_f32(y, y);

    // Horner evaluation: acc = c[k] + z * acc, starting from the top coefficient.
    HVX_Vector acc = hvx_vec_splat_f32(coeffs[0]);
    for (int k = 1; k < 5; k++) {
        acc = hvx_vec_add_f32_f32(hvx_vec_splat_f32(coeffs[k]), hvx_vec_mul_f32_f32(z, acc));
    }

    // sin(y) = y * P(y^2); apply the parity sign last.
    const HVX_Vector sin_y = hvx_vec_mul_f32_f32(y, acc);

    return hvx_vec_mul_f32_f32(sin_y, sign);
}
|
||||
|
||||
#endif /* HVX_SIN_COS_H */
|
||||
|
|
@ -14,6 +14,8 @@
|
|||
#include "hvx-sqrt.h"
|
||||
#include "hvx-arith.h"
|
||||
#include "hvx-div.h"
|
||||
#include "hvx-floor.h"
|
||||
#include "hvx-sin-cos.h"
|
||||
#include "hvx-base.h"
|
||||
|
||||
#endif /* HVX_UTILS_H */
|
||||
|
|
|
|||
|
|
@ -420,8 +420,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
|||
|
||||
ctx->n_threads = n_hvx;
|
||||
for (int i = 0; i < ctx->n_threads; i++) {
|
||||
// see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541
|
||||
ctx->dma[i] = dma_queue_create(128);
|
||||
ctx->dma[i] = dma_queue_create(256); // queue depth
|
||||
}
|
||||
|
||||
// init worker pool
|
||||
|
|
@ -601,6 +600,9 @@ static int execute_op(struct htp_ops_context * octx) {
|
|||
case HTP_OP_PAD:
|
||||
return op_pad(octx);
|
||||
|
||||
case HTP_OP_CONCAT:
|
||||
return op_concat(octx);
|
||||
|
||||
case HTP_OP_GATED_DELTA_NET:
|
||||
return op_gated_delta_net(octx);
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
|
@ -75,6 +76,9 @@ struct htp_rope_context {
|
|||
size_t theta_cache_offset;
|
||||
uint32_t src0_nrows;
|
||||
|
||||
struct fastdiv_values div_ne2_ne1;
|
||||
struct fastdiv_values div_ne1;
|
||||
|
||||
uint64_t t_start;
|
||||
};
|
||||
|
||||
|
|
@ -117,13 +121,84 @@ static __attribute__((noinline)) void rope_cache_init(const float theta_base,
|
|||
float * cache,
|
||||
const float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
float theta = theta_base;
|
||||
#if __HVX_ARCH__ >= 79
|
||||
const bool is_v79_or_newer = true;
|
||||
#else
|
||||
const bool is_v79_or_newer = false;
|
||||
#endif
|
||||
|
||||
for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
|
||||
rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache);
|
||||
if (is_v79_or_newer && ext_factor == 0.0f) {
|
||||
// Fast path: fully vectorized
|
||||
// We process 32 pairs (64 elements) per iteration.
|
||||
const uint32_t n_blocks = ne0 / 64;
|
||||
|
||||
theta *= theta_scale;
|
||||
// Initialize theta scale powers: [1.0f, theta_scale, theta_scale^2, ..., theta_scale^31]
|
||||
float __attribute__((aligned(128))) theta_powers[32];
|
||||
theta_powers[0] = 1.0f;
|
||||
for (int j = 1; j < 32; j++) {
|
||||
theta_powers[j] = theta_powers[j - 1] * theta_scale;
|
||||
}
|
||||
HVX_Vector v_theta_powers = hvx_vmem(theta_powers);
|
||||
|
||||
HVX_Vector v_freq_scale = hvx_vec_splat_f32(freq_scale);
|
||||
HVX_Vector v_mscale = hvx_vec_splat_f32(mscale);
|
||||
|
||||
// Base theta starts at theta_base
|
||||
float theta_block = theta_base;
|
||||
// The scale factor for the next block is theta_scale^32
|
||||
float theta_scale_32 = 1.0f;
|
||||
for (int j = 0; j < 32; j++) {
|
||||
theta_scale_32 *= theta_scale;
|
||||
}
|
||||
|
||||
for (uint32_t b = 0; b < n_blocks; b++) {
|
||||
uint32_t i0 = b * 64;
|
||||
HVX_Vector v_theta_base = hvx_vec_splat_f32(theta_block);
|
||||
HVX_Vector v_theta = hvx_vec_mul_f32_f32(v_theta_base, v_theta_powers);
|
||||
|
||||
if (freq_factors) {
|
||||
// Load 32 elements of freq_factors
|
||||
HVX_Vector v_ff = hvx_vmemu(freq_factors + i0 / 2);
|
||||
HVX_Vector v_inv_ff = hvx_vec_inverse_f32(v_ff);
|
||||
v_theta = hvx_vec_mul_f32_f32(v_theta, v_inv_ff);
|
||||
}
|
||||
|
||||
HVX_Vector v_theta_final = hvx_vec_mul_f32_f32(v_theta, v_freq_scale);
|
||||
|
||||
HVX_Vector vcos = hvx_vec_cos_f32(v_theta_final);
|
||||
HVX_Vector vsin = hvx_vec_sin_f32(v_theta_final);
|
||||
|
||||
vcos = hvx_vec_mul_f32_f32(vcos, v_mscale);
|
||||
vsin = hvx_vec_mul_f32_f32(vsin, v_mscale);
|
||||
|
||||
HVX_VectorPair vstore = Q6_W_vshuff_VVR(vsin, vcos, -4);
|
||||
|
||||
if (((uintptr_t)cache) % 128 == 0) {
|
||||
hvx_vmem(cache + i0 + 0) = Q6_V_lo_W(vstore);
|
||||
hvx_vmem(cache + i0 + 32) = Q6_V_hi_W(vstore);
|
||||
} else {
|
||||
hvx_vec_store_u(cache + i0 + 0, 32 * sizeof(float), Q6_V_lo_W(vstore));
|
||||
hvx_vec_store_u(cache + i0 + 32, 32 * sizeof(float), Q6_V_hi_W(vstore));
|
||||
}
|
||||
|
||||
theta_block *= theta_scale_32;
|
||||
}
|
||||
|
||||
// Leftovers
|
||||
float theta = theta_block;
|
||||
for (uint32_t i0 = n_blocks * 64; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
|
||||
rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache);
|
||||
theta *= theta_scale;
|
||||
}
|
||||
} else {
|
||||
// Fallback to original scalar loop
|
||||
float theta = theta_base;
|
||||
for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
|
||||
rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache);
|
||||
theta *= theta_scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -195,24 +270,18 @@ static void rope_corr_dims(int n_dims,
|
|||
}
|
||||
|
||||
static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
|
||||
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
const uint32_t he = ne / 2;
|
||||
const uint32_t nvec = he / 32;
|
||||
const uint32_t nloe = he % 32;
|
||||
|
||||
uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2
|
||||
for (uint32_t i = 0; i < nvec; i++) {
|
||||
HVX_Vector v0 = ((const HVX_Vector *) src0)[i];
|
||||
HVX_Vector v1 = hvx_vmemu(src0 + he + i * 32);
|
||||
|
||||
uint32_t he = ne / 2; // half_dims offset in elements
|
||||
uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors
|
||||
HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0];
|
||||
HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1];
|
||||
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i = 0; i < nvec; i += 2) {
|
||||
HVX_Vector v0 = vsrc[i/2+0];
|
||||
HVX_Vector v1 = vsrc[i/2+hv];
|
||||
|
||||
HVX_Vector v2 = vtheta[i+0];
|
||||
HVX_Vector v3 = vtheta[i+1];
|
||||
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);
|
||||
|
||||
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin));
|
||||
|
|
@ -222,37 +291,45 @@ static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * rest
|
|||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||
|
||||
vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4);
|
||||
vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);
|
||||
((HVX_Vector *) dst)[i] = Q6_Vsf_equals_Vqf32(v4);
|
||||
hvx_vmemu(dst + he + i * 32) = Q6_Vsf_equals_Vqf32(v5);
|
||||
}
|
||||
|
||||
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
|
||||
const float cos_theta = theta_cache[i+0];
|
||||
const float sin_theta = theta_cache[i+1];
|
||||
float x0 = src0[i/2];
|
||||
float x1 = src0[i/2 + he];
|
||||
dst[i/2] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;
|
||||
if (nloe > 0) {
|
||||
HVX_Vector v0 = hvx_vmemu(src0 + nvec * 32);
|
||||
HVX_Vector v1 = hvx_vmemu(src0 + he + nvec * 32);
|
||||
|
||||
HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 0];
|
||||
HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 1];
|
||||
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);
|
||||
|
||||
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin));
|
||||
HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin));
|
||||
|
||||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||
|
||||
hvx_vec_store_u(dst + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v4));
|
||||
hvx_vec_store_u(dst + he + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v5));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
|
||||
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
const uint32_t nvec = ne / 64;
|
||||
const uint32_t nloe = ne % 64;
|
||||
|
||||
uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two
|
||||
for (uint32_t i = 0; i < nvec; i++) {
|
||||
HVX_Vector v0 = ((const HVX_Vector *) src0)[i * 2 + 0];
|
||||
HVX_Vector v1 = ((const HVX_Vector *) src0)[i * 2 + 1];
|
||||
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i = 0; i < nvec; i+=2) {
|
||||
HVX_Vector v0 = vsrc[i+0];
|
||||
HVX_Vector v1 = vsrc[i+1];
|
||||
HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0];
|
||||
HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1];
|
||||
|
||||
HVX_Vector v2 = vtheta[i+0];
|
||||
HVX_Vector v3 = vtheta[i+1];
|
||||
|
||||
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4);
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);
|
||||
|
||||
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin));
|
||||
|
|
@ -264,17 +341,52 @@ static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict
|
|||
|
||||
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
|
||||
|
||||
vdst[i+0] = Q6_V_lo_W(vstore);
|
||||
vdst[i+1] = Q6_V_hi_W(vstore);
|
||||
((HVX_Vector *) dst)[i * 2 + 0] = Q6_V_lo_W(vstore);
|
||||
((HVX_Vector *) dst)[i * 2 + 1] = Q6_V_hi_W(vstore);
|
||||
}
|
||||
|
||||
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
|
||||
const float cos_theta = theta_cache[i+0];
|
||||
const float sin_theta = theta_cache[i+1];
|
||||
float x0 = src0[i+0];
|
||||
float x1 = src0[i+1];
|
||||
dst[i+0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i+1] = x0 * sin_theta + x1 * cos_theta;
|
||||
if (nloe > 0) {
|
||||
if (nloe <= 32) {
|
||||
HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64);
|
||||
HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64);
|
||||
|
||||
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(Q6_V_vzero(), v0, -4);
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(Q6_V_vzero(), v2, -4);
|
||||
|
||||
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin));
|
||||
HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin));
|
||||
|
||||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||
|
||||
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
|
||||
|
||||
hvx_vec_store_u(dst + nvec * 64, nloe * sizeof(float), Q6_V_lo_W(vstore));
|
||||
} else {
|
||||
HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64);
|
||||
HVX_Vector v1 = hvx_vmemu(src0 + nvec * 64 + 32);
|
||||
|
||||
HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64);
|
||||
HVX_Vector v3 = hvx_vmemu(theta_cache + nvec * 64 + 32);
|
||||
|
||||
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4);
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);
|
||||
|
||||
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin));
|
||||
HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin));
|
||||
|
||||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||
|
||||
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
|
||||
|
||||
((HVX_Vector *) dst)[nvec * 2 + 0] = Q6_V_lo_W(vstore);
|
||||
hvx_vec_store_u(dst + nvec * 64 + 32, (nloe - 32) * sizeof(float), Q6_V_hi_W(vstore));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -348,13 +460,19 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
|||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * freq_factors = src2 ? (const float *) src2->data : NULL;
|
||||
|
||||
uint32_t ir = 0;
|
||||
const uint32_t i3_start = fastdiv(src0_start_row, &rctx->div_ne2_ne1);
|
||||
const uint32_t rem = fastmodulo(src0_start_row, ne2 * ne1, &rctx->div_ne2_ne1);
|
||||
const uint32_t i2_start = fastdiv(rem, &rctx->div_ne1);
|
||||
const uint32_t i1_start = fastmodulo(rem, ne1, &rctx->div_ne1);
|
||||
|
||||
uint32_t ir = src0_start_row;
|
||||
uint32_t prev_i2 = (uint32_t) -1;
|
||||
|
||||
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads
|
||||
if (ir < src0_start_row) { ir++; i1++; continue; }
|
||||
for (uint32_t i3 = i3_start; i3 < ne3; i3++) { // batch
|
||||
const uint32_t i2_init = (i3 == i3_start) ? i2_start : 0;
|
||||
for (uint32_t i2 = i2_init; i2 < ne2; i2++) { // seq-len
|
||||
const uint32_t i1_init = (i3 == i3_start && i2 == i2_start) ? i1_start : 0;
|
||||
for (uint32_t i1 = i1_init; i1 < ne1; ) { // attn-heads
|
||||
if (ir >= src0_end_row) goto done;
|
||||
|
||||
// Rows in this block
|
||||
|
|
@ -407,9 +525,6 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
|||
ne0, rctx->ext_factor, rctx->attn_factor,
|
||||
theta_cache, rctx->theta_scale);
|
||||
}
|
||||
|
||||
// FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
}
|
||||
|
||||
// Skip output DMA transactions from prev block (if any)
|
||||
|
|
@ -489,7 +604,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
|||
// Aligned row sizes for VTCM
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);
|
||||
const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 256);
|
||||
|
||||
// Calculate spad sizes per thread
|
||||
size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;
|
||||
|
|
@ -546,6 +661,11 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
|||
rctx.src0_nrows = src0_nrows;
|
||||
rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
|
||||
if (src0_nrows > 0) {
|
||||
rctx.div_ne2_ne1 = init_fastdiv_values(dst->ne[2] * dst->ne[1]);
|
||||
rctx.div_ne1 = init_fastdiv_values(dst->ne[1]);
|
||||
}
|
||||
|
||||
FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
|
||||
rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
|
||||
|
||||
|
|
|
|||
|
|
@ -65,6 +65,9 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
|||
// parallelize by rows of src0
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
if (ir0 >= nr) {
|
||||
return;
|
||||
}
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
|
@ -109,6 +112,9 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
|
|||
// parallelize by rows of src0
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
if (ir0 >= nr) {
|
||||
return;
|
||||
}
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ static void hvx_fast_norm_f32(const uint8_t * restrict src,
|
|||
|
||||
// scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction
|
||||
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v));
|
||||
HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v)));
|
||||
HVX_Vector mean_x_b = hvx_vec_repl_f32(Q6_Vsf_equals_Vqf32(mean_x_v));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue