diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 9db99cb0f..1c8ecc197 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2874,6 +2874,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_NORM: return HTP_OP_NORM; case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_CONCAT: return HTP_OP_CONCAT; case GGML_OP_SCALE: return HTP_OP_SCALE; case GGML_OP_SQR: return HTP_OP_SQR; case GGML_OP_SQRT: return HTP_OP_SQRT; @@ -3286,6 +3287,25 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_concat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + int dim = ((const int32_t *) op->op_params)[0]; + if (dim < 0 || dim >= GGML_MAX_DIMS) { + return false; + } + + for (int i = 0; i < GGML_MAX_SRC; ++i) { + const struct ggml_tensor * src = op->src[i]; + if (!src) { + continue; + } + if (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_I32 && src->type != GGML_TYPE_F16) { + return false; + } + } + + return true; +} + static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * dst = op; @@ -3434,6 +3454,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_CONCAT: + supp = ggml_hexagon_supported_concat(sess, op); + break; + case GGML_OP_FILL: supp = ggml_hexagon_supported_fill(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 36f923243..33d67dda9 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -35,6 +35,7 @@ add_library(${HTP_LIB} SHARED ssm-conv.c cumsum-ops.c fill-ops.c + concat-ops.c diag-ops.c solve-tri-ops.c gated-delta-net-ops.c diff --git a/ggml/src/ggml-hexagon/htp/concat-ops.c b/ggml/src/ggml-hexagon/htp/concat-ops.c new file mode 100644 index 000000000..61580f2c0 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/concat-ops.c @@ -0,0 +1,275 @@ +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hexagon_types.h" +#include "hexagon_protos.h" +#include "hvx_hexagon_protos.h" +#include "hex-dma.h" +#include "vtcm-utils.h" +#include "hvx-utils.h" +#include "hex-fastdiv.h" +#include + +struct htp_concat_context { + struct htp_ops_context * octx; + uint32_t dim; + uint32_t nrows_per_thread; + struct fastdiv_values div_ne0; + struct fastdiv_values div_ne1; + struct fastdiv_values div_ne2; +}; + +static void concat_2d_f32_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 32; + const uint32_t spad1_stride = block_i * sizeof(float); + + int32_t offsets[32] __attribute__((aligned(128))); + for(int k=0; k<32; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 32); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(float), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(float); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(float); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 32) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(float)); + Q6_vgather_ARMVw(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(float); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(float), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_2d_f16_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 64; + const uint32_t spad1_stride = block_i * sizeof(__fp16); + + int16_t offsets[64] __attribute__((aligned(128))); + for(int k=0; k<64; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 64); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(__fp16), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(__fp16); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(__fp16); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 64) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(__fp16)); + Q6_vgather_ARMVh(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(__fp16); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(__fp16), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_generic(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const int dim = cctx->dim; + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + + const uint32_t ne[4] = {dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]}; + const uint32_t total_elements = ne[0] * ne[1] * ne[2] * ne[3]; + const uint32_t chunk_size = (total_elements + nth - 1) / nth; + + const uint32_t start_idx = MIN(ith * chunk_size, total_elements); + const uint32_t end_idx = MIN(start_idx + chunk_size, total_elements); + + // Naive scalar element-wise copy + for (uint32_t idx = start_idx; idx < end_idx; idx++) { + uint32_t idx_div_ne0 = fastdiv(idx, &cctx->div_ne0); + uint32_t i0 = idx - idx_div_ne0 * ne[0]; + + uint32_t idx_div_ne01 = fastdiv(idx_div_ne0, &cctx->div_ne1); + uint32_t i1 = idx_div_ne0 - idx_div_ne01 * ne[1]; + + uint32_t idx_div_ne012 = fastdiv(idx_div_ne01, &cctx->div_ne2); + uint32_t i2 = idx_div_ne01 - idx_div_ne012 * ne[2]; + uint32_t i3 = idx_div_ne012; + + uint8_t * dst_ptr = (uint8_t *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2] + i1 * dst->nb[1] + i0 * dst->nb[0]; + + uint32_t idx_dim = 0; + if (dim == 0) idx_dim = i0; + else if (dim == 1) idx_dim = i1; + else if (dim == 2) idx_dim = i2; + else if (dim == 3) idx_dim = i3; + + const struct htp_tensor * src = (idx_dim < src0->ne[dim]) ? src0 : src1; + + uint32_t s0 = i0; + uint32_t s1 = i1; + uint32_t s2 = i2; + uint32_t s3 = i3; + + if (dim == 0 && src == src1) s0 -= src0->ne[0]; + if (dim == 1 && src == src1) s1 -= src0->ne[1]; + if (dim == 2 && src == src1) s2 -= src0->ne[2]; + if (dim == 3 && src == src1) s3 -= src0->ne[3]; + + uint8_t * src_ptr = (uint8_t *)src->data + s3 * src->nb[3] + s2 * src->nb[2] + s1 * src->nb[1] + s0 * src->nb[0]; + + if (type_size == 4) { + *(float*)dst_ptr = *(float*)src_ptr; + } else { + *(__fp16*)dst_ptr = *(__fp16*)src_ptr; + } + } +} + +int op_concat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + int dim = octx->op_params[0]; + + bool is_2d = dst->ne[2] == 1 && dst->ne[3] == 1; + + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + bool is_src1_transposed = (src1->nb[0] > src1->nb[1]); + bool is_src0_transposed = (src0->nb[0] > src0->nb[1]); + + uint32_t n_threads = octx->n_threads; + struct htp_concat_context cctx; + cctx.octx = octx; + cctx.dim = dim; + cctx.div_ne0 = init_fastdiv_values(dst->ne[0]); + cctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + cctx.div_ne2 = init_fastdiv_values(dst->ne[2]); + + void (*worker_func)(unsigned int, unsigned int, void *) = concat_generic; + + if (dim == 0 && is_2d && is_src1_transposed && !is_src0_transposed) { + n_threads = MIN(dst->ne[1], n_threads); + if (n_threads < 1) { + n_threads = 1; + } + uint32_t block_i = (type_size == 4) ? 32 : 64; + + cctx.nrows_per_thread = hmx_ceil_div(dst->ne[1], n_threads); + + // Allocate VTCM + uint32_t spad1_stride = block_i * type_size; + + uint32_t src1_ne0_padded = hex_round_up(src1->ne[0], block_i); + uint32_t spad0_row_bytes = hex_round_up((src0->ne[0] + src1_ne0_padded) * type_size, VLEN); + + octx->src0_spad.size_per_thread = block_i * spad0_row_bytes; + octx->src1_spad.size_per_thread = src1_ne0_padded * spad1_stride + block_i * VLEN; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + + if (octx->src0_spad.size + octx->src1_spad.size > octx->ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + if (type_size == 4) { + worker_func = concat_2d_f32_transposed; + } else { + worker_func = concat_2d_f16_transposed; + } + } + + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &cctx, n_threads); + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index 5c040a322..ae507effa 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -28,158 +28,170 @@ struct htp_copy_context { uint32_t dst_blocks_per_row; uint32_t src0_nrows_per_thread; - - void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith); }; #define cpy_preamble \ const struct htp_tensor *src0 = octx->src[0]; \ const struct htp_tensor *dst = octx->dst; \ \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; \ - \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ const uint32_t nr = ne01; -static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { - cpy_preamble; - - // parallelize by src0 rows - const uint32_t dr = ct->src0_nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; - - // copy by rows - for (uint32_t i03 = 0; i03 < ne03; i03++) { - for (uint32_t i02 = 0; i02 < ne02; i02++) { - #pragma unroll(2) - for (uint32_t i01 = ir0; i01 < ir1; i01++) { - uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; - uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2); - hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size); - } - } - } +#define DEFINE_CPY_SAMESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_sameshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + for (uint32_t i03 = 0; i03 < ne03; i03++) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + _Pragma("unroll(4)") \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; \ + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; \ + hex_l2fetch(src0_ptr, ne00 * ELEM_SIZE, nb01, 2); \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + } \ } -static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) { - cpy_preamble; +DEFINE_CPY_SAMESHAPE(f32, float, 4) +DEFINE_CPY_SAMESHAPE(f16, __fp16, 2) - // parallelize by src0 rows - const uint32_t dr = ct->src0_nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; - - // Fast path: when both src0 and dst are contiguous in memory - // Replace the element-by-element loop with a single bulk HVX copy per (i03, i02) slice. - const bool src0_contig = (nb00 == ct->src0_type_size) && - (nb01 == ne00 * nb00) && - (nb02 == ne01 * nb01) && - (nb03 == ne02 * nb02); - const bool dst_contig = (nb0 == ct->dst_type_size) && - (nb1 == ne0 * nb0) && - (nb2 == ne1 * nb1) && - (nb3 == ne2 * nb2); - - if (src0_contig && dst_contig) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; - uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; - uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ct->src0_type_size; - hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ct->src0_type_size); - } - } - return; - } - - // dst counters - int64_t k10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - // number of blocks in a row - const int64_t nk00 = ct->src0_blocks_per_row; - const int64_t nk0 = ct->dst_blocks_per_row; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - k10 += nk00 * ir0; - while (k10 >= nk0) { - k10 -= nk0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t k00 = 0; k00 < nk00; k00++) { - const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - memcpy(dst_ptr, src0_ptr, ct->dst_type_size); - - if (++k10 == nk0) { - k10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - k10 += nk00 * (ne01 - ir1); - while (k10 >= nk0) { - k10 -= nk0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } +#define DEFINE_CPY_RESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_reshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + const bool src0_contig = (nb00 == ELEM_SIZE) && \ + (nb01 == ne00 * nb00) && \ + (nb02 == ne01 * nb01) && \ + (nb03 == ne02 * nb02); \ + const bool dst_contig = (nb0 == ELEM_SIZE) && \ + (nb1 == ne0 * nb0) && \ + (nb2 == ne1 * nb1) && \ + (nb3 == ne2 * nb2); \ + if (src0_contig && dst_contig) { \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; \ + uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ELEM_SIZE; \ + hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + const bool reshape_flat_fast = (ne03 == 1 && ne2 == 1 && ne3 == 1) && \ + (ne0 == ne00 * ne01) && (ne1 == ne02) && \ + (nb00 == ELEM_SIZE) && (nb0 == ELEM_SIZE); \ + if (reshape_flat_fast) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t * src0_ptr = (uint8_t *) src0->data + i01 * nb01 + i02 * nb02; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + i01 * ne00 * ELEM_SIZE + i02 * nb1; \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + int64_t k10 = 0; \ + int64_t i11 = 0; \ + int64_t i12 = 0; \ + int64_t i13 = 0; \ + const int64_t nk00 = ct->src0_blocks_per_row; \ + const int64_t nk0 = ct->dst_blocks_per_row; \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + k10 += nk00 * ir0; \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + for (int64_t i01 = ir0; i01 < ir1; i01++) { \ + for (int64_t k00 = 0; k00 < nk00; k00++) { \ + const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); \ + char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); \ + memcpy(dst_ptr, src0_ptr, ELEM_SIZE); \ + if (++k10 == nk0) { \ + k10 = 0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ + k10 += nk00 * (ne01 - ir1); \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ } -static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { +DEFINE_CPY_RESHAPE(f32, float, 4) +DEFINE_CPY_RESHAPE(f16, __fp16, 2) + +static void cpy_thread_f16_f32_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; cpy_preamble; // parallelize by src0 rows const uint32_t dr = ct->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; // copy by rows for (uint32_t i03 = 0; i03 < ne03; i03++) { @@ -195,13 +207,16 @@ static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct ht } } -static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { +static void cpy_thread_f32_f16_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; cpy_preamble; // parallelize by src0 rows const uint32_t dr = ct->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; // copy by rows for (uint32_t i03 = 0; i03 < ne03; i03++) { @@ -217,11 +232,6 @@ static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct ht } } -static void cpy_work_func(unsigned int n, unsigned int i, void *data) { - struct htp_copy_context *ct = (struct htp_copy_context *) data; - ct->copy(ct, ct->octx, n, i); -} - int op_cpy(struct htp_ops_context * octx) { cpy_preamble; @@ -254,22 +264,32 @@ int op_cpy(struct htp_ops_context * octx) { ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; + worker_callback_t copy_fun; + if (sametype && sameshape) { - ct.copy = cpy_thread_sametype_sameshape; + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_sameshape; + } else { + copy_fun = cpy_thread_f16_sameshape; + } } else if (sameshape) { /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32) - ct.copy = cpy_thread_f16_f32_sameshape; + copy_fun = cpy_thread_f16_f32_sameshape; else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16) - ct.copy = cpy_thread_f32_f16_sameshape; + copy_fun = cpy_thread_f32_f16_sameshape; else return HTP_STATUS_NO_SUPPORT; } else if (sametype) { - ct.copy = cpy_thread_sametype_reshape; + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_reshape; + } else { + copy_fun = cpy_thread_f16_reshape; + } } else { return HTP_STATUS_NO_SUPPORT; } - worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads); + worker_pool_run_func(octx->ctx->worker_pool, copy_fun, &ct, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index 5a1dc9338..bf7063e98 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -17,9 +17,13 @@ struct get_rows_context { struct htp_ops_context * octx; - uint32_t src1_nrows_per_thread; + uint32_t tasks_per_thread; + uint32_t total_tasks; + uint32_t chunks_per_row; + uint32_t chunk_size; struct fastdiv_values get_rows_div_ne10; struct fastdiv_values get_rows_div_ne10_ne11; + struct fastdiv_values get_rows_div_chunks_per_row; }; #define get_rows_preamble \ @@ -52,20 +56,23 @@ struct get_rows_context { \ const uint32_t nr = ne10 * ne11 * ne12; -static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { +static void get_rows_thread_f32_f32_dma(unsigned int nth, unsigned int ith, void *data) { struct get_rows_context * grctx = (struct get_rows_context *)data; struct htp_ops_context * octx = grctx->octx; get_rows_preamble; uint64_t qt = HAP_perf_get_qtimer_count(); - // parallelize by src1 elements (which correspond to dst rows) - const uint32_t dr = grctx->src1_nrows_per_thread; + const uint32_t dr = grctx->tasks_per_thread; const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + dma_queue * dma_queue = octx->ctx->dma[ith]; for (uint32_t i = ir0; i < ir1; ++i) { const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); const uint32_t rem = i - i12 * ne11 * ne10; @@ -73,29 +80,77 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i10 = rem - i11 * ne10; const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; - uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i01 >= ne01) { - // invalid index, skip for now to avoid crash continue; } const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03; const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3; - hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + + while (!dma_queue_push(dma_queue, dma_make_ptr((void *)dst_ptr, (const void *)src0_ptr), nb1, nb01, ne00 * sizeof(float), 1)) { + dma_queue_pop(dma_queue); + } + } + dma_queue_flush(dma_queue); + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "get-rows-f32-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); +} + +static void get_rows_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + struct get_rows_context * grctx = (struct get_rows_context *)data; + struct htp_ops_context * octx = grctx->octx; + get_rows_preamble; + + uint64_t qt = HAP_perf_get_qtimer_count(); + + const uint32_t dr = grctx->tasks_per_thread; + const uint32_t ir0 = dr * ith; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); + + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + + const uint32_t chunks_per_row = grctx->chunks_per_row; + const uint32_t chunk_size = grctx->chunk_size; + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t row_idx = fastdiv(i, &grctx->get_rows_div_chunks_per_row); + const uint32_t chunk_idx = i - row_idx * chunks_per_row; + + const uint32_t i12 = fastdiv(row_idx, &grctx->get_rows_div_ne10_ne11); + const uint32_t rem = row_idx - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; + uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + + if (i01 >= ne01) { + continue; + } + + const uint32_t offset = chunk_idx * chunk_size; + if (offset < ne00) { + const uint32_t copy_size = MIN(chunk_size, ne00 - offset); + const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03 + offset * sizeof(float); + const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3 + offset * sizeof(float); + hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, copy_size); + } } qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); - FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + FARF(HIGH, "get-rows-f32-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } int op_get_rows(struct htp_ops_context * octx) { get_rows_preamble; - const uint32_t n_threads = MIN(nr, octx->n_threads); - if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } @@ -112,13 +167,52 @@ int op_get_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } + const uint32_t nb00 = octx->src[0]->nb[0]; + const uint32_t nb0 = octx->dst->nb[0]; + + const bool can_use_dma = (nb00 == sizeof(float)) && (nb0 == sizeof(float)); + const bool use_dma = can_use_dma && (ne00 >= 2048); + struct get_rows_context grctx; grctx.octx = octx; grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]); grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]); - grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; + if (use_dma) { + grctx.chunks_per_row = 1; + grctx.chunk_size = ne00; + grctx.total_tasks = nr; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(1); - worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads); + const uint32_t n_threads = MIN(nr, octx->n_threads); + grctx.tasks_per_thread = (nr + n_threads - 1) / n_threads; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_dma, &grctx, n_threads); + } else { + uint32_t chunks_per_row = 1; + uint32_t chunk_size = ne00; + uint32_t total_tasks = nr; + + if (nr < octx->n_threads) { + const uint32_t min_chunk_size = 1024; + uint32_t max_chunks = ne00 / min_chunk_size; + if (max_chunks == 0) { + max_chunks = 1; + } + chunks_per_row = MIN((octx->n_threads + nr - 1) / nr, max_chunks); + chunk_size = (ne00 + chunks_per_row - 1) / chunks_per_row; + total_tasks = nr * chunks_per_row; + } + + grctx.chunks_per_row = chunks_per_row; + grctx.chunk_size = chunk_size; + grctx.total_tasks = total_tasks; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(chunks_per_row); + + const uint32_t n_threads = MIN(total_tasks, octx->n_threads); + grctx.tasks_per_thread = (total_tasks + n_threads - 1) / n_threads; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_hvx, &grctx, n_threads); + } return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 9e1b778b0..a496f6289 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -50,8 +50,8 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong - const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf - const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf + const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf + const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] @@ -1278,7 +1278,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { struct hmx_fa_context factx; memset(&factx, 0, sizeof(factx)); factx.octx = octx; - factx.n_threads = octx->ctx->n_threads; + factx.n_threads = n_threads; factx.DK = DK; factx.DV = DV; factx.n_kv = nek1; @@ -1328,10 +1328,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); // ======== VTCM allocation (GQA-aware) ======== + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_k_row_padded = hex_round_up(size_k_row, 128); + const size_t size_v_row_padded = hex_round_up(size_v_row, 128); + const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); - const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); - const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t k_dma_bytes = hex_align_up(Bc * size_k_row_padded, 4096); + const size_t v_dma_bytes = hex_align_up(Bc * size_v_row_padded, 4096); const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); @@ -1401,11 +1406,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ======== DMA setup ======== dma_queue * const dma = ctx->dma[0]; - // Padded row sizes for DMA - const size_t size_k_row = nek0 * sizeof(__fp16); - const size_t size_v_row = nev0 * sizeof(__fp16); - const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128); - const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128); + // Padded row sizes for DMA (defined in outer scope) const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 6fe3e6c7d..51f9243ce 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -104,6 +104,7 @@ int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); int op_fill(struct htp_ops_context * octx); +int op_concat(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 9d905a301..54cfadd9b 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -89,6 +89,7 @@ enum htp_op_code { HTP_OP_TRI, HTP_OP_PAD, HTP_OP_NORM, + HTP_OP_CONCAT, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h new file mode 100644 index 000000000..c5b9a5d47 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h @@ -0,0 +1,90 @@ +#ifndef HVX_SIN_COS_H +#define HVX_SIN_COS_H + +#include "hvx-base.h" +#include "hvx-floor.h" + +static inline HVX_Vector hvx_vec_cos_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for cos(y) + HVX_Vector c4 = hvx_vec_splat_f32(2.3557242013849433e-05f); + HVX_Vector c3 = hvx_vec_splat_f32(-0.0013871428263450528f); + HVX_Vector c2 = hvx_vec_splat_f32(0.041665895266688284f); + HVX_Vector c1 = hvx_vec_splat_f32(-0.4999999360426369f); + HVX_Vector c0 = hvx_vec_splat_f32(0.9999999999071725f); + + HVX_Vector cos_y = hvx_vec_add_f32_f32(c3, hvx_vec_mul_f32_f32(z, c4)); + cos_y = hvx_vec_add_f32_f32(c2, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c1, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c0, hvx_vec_mul_f32_f32(z, cos_y)); + + return hvx_vec_mul_f32_f32(cos_y, sign); +} + +static inline HVX_Vector hvx_vec_sin_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for sin(y) + HVX_Vector s4 = hvx_vec_splat_f32(2.642186986152672e-06f); + HVX_Vector s3 = hvx_vec_splat_f32(-0.00019825318964070864f); + HVX_Vector s2 = hvx_vec_splat_f32(0.00833326283319605f); + HVX_Vector s1 = hvx_vec_splat_f32(-0.16666666082087775f); + HVX_Vector s0 = hvx_vec_splat_f32(0.999999999915155f); + + HVX_Vector sin_y = hvx_vec_add_f32_f32(s3, hvx_vec_mul_f32_f32(z, s4)); + sin_y = hvx_vec_add_f32_f32(s2, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s1, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s0, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_mul_f32_f32(y, sin_y); + + return hvx_vec_mul_f32_f32(sin_y, sign); +} + +#endif /* HVX_SIN_COS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index e0452811e..0a760cd34 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -14,6 +14,8 @@ #include "hvx-sqrt.h" #include "hvx-arith.h" #include "hvx-div.h" +#include "hvx-floor.h" +#include "hvx-sin-cos.h" #include "hvx-base.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e86193884..f3a0866c7 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -420,8 +420,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->n_threads = n_hvx; for (int i = 0; i < ctx->n_threads; i++) { - // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 - ctx->dma[i] = dma_queue_create(128); + ctx->dma[i] = dma_queue_create(256); // queue depth } // init worker pool @@ -601,6 +600,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_PAD: return op_pad(octx); + case HTP_OP_CONCAT: + return op_concat(octx); + case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index b398e19f0..c839044b8 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -7,6 +7,7 @@ #include #include +#include #include "hex-dma.h" #include "hvx-utils.h" @@ -75,6 +76,9 @@ struct htp_rope_context { size_t theta_cache_offset; uint32_t src0_nrows; + struct fastdiv_values div_ne2_ne1; + struct fastdiv_values div_ne1; + uint64_t t_start; }; @@ -117,13 +121,84 @@ static __attribute__((noinline)) void rope_cache_init(const float theta_base, float * cache, const float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - float theta = theta_base; +#if __HVX_ARCH__ >= 79 + const bool is_v79_or_newer = true; +#else + const bool is_v79_or_newer = false; +#endif - for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { - const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; - rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + if (is_v79_or_newer && ext_factor == 0.0f) { + // Fast path: fully vectorized + // We process 32 pairs (64 elements) per iteration. + const uint32_t n_blocks = ne0 / 64; - theta *= theta_scale; + // Initialize theta scale powers: [1.0f, theta_scale, theta_scale^2, ..., theta_scale^31] + float __attribute__((aligned(128))) theta_powers[32]; + theta_powers[0] = 1.0f; + for (int j = 1; j < 32; j++) { + theta_powers[j] = theta_powers[j - 1] * theta_scale; + } + HVX_Vector v_theta_powers = hvx_vmem(theta_powers); + + HVX_Vector v_freq_scale = hvx_vec_splat_f32(freq_scale); + HVX_Vector v_mscale = hvx_vec_splat_f32(mscale); + + // Base theta starts at theta_base + float theta_block = theta_base; + // The scale factor for the next block is theta_scale^32 + float theta_scale_32 = 1.0f; + for (int j = 0; j < 32; j++) { + theta_scale_32 *= theta_scale; + } + + for (uint32_t b = 0; b < n_blocks; b++) { + uint32_t i0 = b * 64; + HVX_Vector v_theta_base = hvx_vec_splat_f32(theta_block); + HVX_Vector v_theta = hvx_vec_mul_f32_f32(v_theta_base, v_theta_powers); + + if (freq_factors) { + // Load 32 elements of freq_factors + HVX_Vector v_ff = hvx_vmemu(freq_factors + i0 / 2); + HVX_Vector v_inv_ff = hvx_vec_inverse_f32(v_ff); + v_theta = hvx_vec_mul_f32_f32(v_theta, v_inv_ff); + } + + HVX_Vector v_theta_final = hvx_vec_mul_f32_f32(v_theta, v_freq_scale); + + HVX_Vector vcos = hvx_vec_cos_f32(v_theta_final); + HVX_Vector vsin = hvx_vec_sin_f32(v_theta_final); + + vcos = hvx_vec_mul_f32_f32(vcos, v_mscale); + vsin = hvx_vec_mul_f32_f32(vsin, v_mscale); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(vsin, vcos, -4); + + if (((uintptr_t)cache) % 128 == 0) { + hvx_vmem(cache + i0 + 0) = Q6_V_lo_W(vstore); + hvx_vmem(cache + i0 + 32) = Q6_V_hi_W(vstore); + } else { + hvx_vec_store_u(cache + i0 + 0, 32 * sizeof(float), Q6_V_lo_W(vstore)); + hvx_vec_store_u(cache + i0 + 32, 32 * sizeof(float), Q6_V_hi_W(vstore)); + } + + theta_block *= theta_scale_32; + } + + // Leftovers + float theta = theta_block; + for (uint32_t i0 = n_blocks * 64; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } + } else { + // Fallback to original scalar loop + float theta = theta_base; + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } } } @@ -195,24 +270,18 @@ static void rope_corr_dims(int n_dims, } static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { - const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; - const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; - HVX_Vector * restrict vdst = (HVX_Vector *) dst; + const uint32_t he = ne / 2; + const uint32_t nvec = he / 32; + const uint32_t nloe = he % 32; - uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2 + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i]; + HVX_Vector v1 = hvx_vmemu(src0 + he + i * 32); - uint32_t he = ne / 2; // half_dims offset in elements - uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - #pragma unroll(2) - for (uint32_t i = 0; i < nvec; i += 2) { - HVX_Vector v0 = vsrc[i/2+0]; - HVX_Vector v1 = vsrc[i/2+hv]; - - HVX_Vector v2 = vtheta[i+0]; - HVX_Vector v3 = vtheta[i+1]; - - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); @@ -222,37 +291,45 @@ static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * rest HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4); - vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5); + ((HVX_Vector *) dst)[i] = Q6_Vsf_equals_Vqf32(v4); + hvx_vmemu(dst + he + i * 32) = Q6_Vsf_equals_Vqf32(v5); } - for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { - const float cos_theta = theta_cache[i+0]; - const float sin_theta = theta_cache[i+1]; - float x0 = src0[i/2]; - float x1 = src0[i/2 + he]; - dst[i/2] = x0 * cos_theta - x1 * sin_theta; - dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta; + if (nloe > 0) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 32); + HVX_Vector v1 = hvx_vmemu(src0 + he + nvec * 32); + + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 1]; + + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + hvx_vec_store_u(dst + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v4)); + hvx_vec_store_u(dst + he + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v5)); } } static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { - const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; - const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; - HVX_Vector * restrict vdst = (HVX_Vector *) dst; + const uint32_t nvec = ne / 64; + const uint32_t nloe = ne % 64; - uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i * 2 + 0]; + HVX_Vector v1 = ((const HVX_Vector *) src0)[i * 2 + 1]; - #pragma unroll(2) - for (uint32_t i = 0; i < nvec; i+=2) { - HVX_Vector v0 = vsrc[i+0]; - HVX_Vector v1 = vsrc[i+1]; + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - HVX_Vector v2 = vtheta[i+0]; - HVX_Vector v3 = vtheta[i+1]; - - HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); @@ -264,17 +341,52 @@ static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - vdst[i+0] = Q6_V_lo_W(vstore); - vdst[i+1] = Q6_V_hi_W(vstore); + ((HVX_Vector *) dst)[i * 2 + 0] = Q6_V_lo_W(vstore); + ((HVX_Vector *) dst)[i * 2 + 1] = Q6_V_hi_W(vstore); } - for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { - const float cos_theta = theta_cache[i+0]; - const float sin_theta = theta_cache[i+1]; - float x0 = src0[i+0]; - float x1 = src0[i+1]; - dst[i+0] = x0 * cos_theta - x1 * sin_theta; - dst[i+1] = x0 * sin_theta + x1 * cos_theta; + if (nloe > 0) { + if (nloe <= 32) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(Q6_V_vzero(), v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(Q6_V_vzero(), v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + hvx_vec_store_u(dst + nvec * 64, nloe * sizeof(float), Q6_V_lo_W(vstore)); + } else { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v1 = hvx_vmemu(src0 + nvec * 64 + 32); + + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); + HVX_Vector v3 = hvx_vmemu(theta_cache + nvec * 64 + 32); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + ((HVX_Vector *) dst)[nvec * 2 + 0] = Q6_V_lo_W(vstore); + hvx_vec_store_u(dst + nvec * 64 + 32, (nloe - 32) * sizeof(float), Q6_V_hi_W(vstore)); + } } } @@ -348,13 +460,19 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { const int32_t * pos = (const int32_t *) src1->data; const float * freq_factors = src2 ? (const float *) src2->data : NULL; - uint32_t ir = 0; + const uint32_t i3_start = fastdiv(src0_start_row, &rctx->div_ne2_ne1); + const uint32_t rem = fastmodulo(src0_start_row, ne2 * ne1, &rctx->div_ne2_ne1); + const uint32_t i2_start = fastdiv(rem, &rctx->div_ne1); + const uint32_t i1_start = fastmodulo(rem, ne1, &rctx->div_ne1); + + uint32_t ir = src0_start_row; uint32_t prev_i2 = (uint32_t) -1; - for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch - for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len - for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads - if (ir < src0_start_row) { ir++; i1++; continue; } + for (uint32_t i3 = i3_start; i3 < ne3; i3++) { // batch + const uint32_t i2_init = (i3 == i3_start) ? i2_start : 0; + for (uint32_t i2 = i2_init; i2 < ne2; i2++) { // seq-len + const uint32_t i1_init = (i3 == i3_start && i2 == i2_start) ? i1_start : 0; + for (uint32_t i1 = i1_init; i1 < ne1; ) { // attn-heads if (ir >= src0_end_row) goto done; // Rows in this block @@ -407,9 +525,6 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); } - - // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, - // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); } // Skip output DMA transactions from prev block (if any) @@ -489,7 +604,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { // Aligned row sizes for VTCM const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); - const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128); + const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 256); // Calculate spad sizes per thread size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned; @@ -546,6 +661,11 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { rctx.src0_nrows = src0_nrows; rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + if (src0_nrows > 0) { + rctx.div_ne2_ne1 = init_fastdiv_values(dst->ne[2] * dst->ne[1]); + rctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + } + FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 0def7b408..58c54967d 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -65,6 +65,9 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); @@ -109,6 +112,9 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 40d2d6015..7d0431d8b 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -207,7 +207,7 @@ static void hvx_fast_norm_f32(const uint8_t * restrict src, // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); - HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v))); + HVX_Vector mean_x_b = hvx_vec_repl_f32(Q6_Vsf_equals_Vqf32(mean_x_v)); #pragma unroll(4) for (int i = 0; i < nvec; i++) {