diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 2f75e97ac..ebeef3bdb 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2870,6 +2870,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_NORM: return HTP_OP_NORM; case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; case GGML_OP_SCALE: return HTP_OP_SCALE; @@ -3338,10 +3339,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_add_id(sess, op); break; + case GGML_OP_NORM: case GGML_OP_L2_NORM: - supp = ggml_hexagon_supported_unary(sess, op); - break; - case GGML_OP_RMS_NORM: case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 676e948a4..9d905a301 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -88,6 +88,7 @@ enum htp_op_code { HTP_OP_GATED_DELTA_NET, HTP_OP_TRI, HTP_OP_PAD, + HTP_OP_NORM, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 12003c1fd..8e54536f6 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -534,6 +534,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_ADD_ID: return op_binary(octx); + case HTP_OP_NORM: case HTP_OP_RMS_NORM: case HTP_OP_SCALE: case HTP_OP_SQR: diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 1ce881353..40d2d6015 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -158,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void hvx_fast_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + (void)pad; + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares and sum of values for full vectors + HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Reduce HVX sums + sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v)); + sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v); + HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v); + HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v)); + HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v); + HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v); + + // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); + HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v))); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + v_dst[i] = Q6_Vsf_equals_Vqf32(v3); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v3); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + static void scale_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -196,6 +269,24 @@ static void rms_norm_f32(const float * restrict src, } } +static void norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); + } +} + static void sqr_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -556,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * // Process block in VTCM switch (htp_op) { + case HTP_OP_NORM: + norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; case HTP_OP_RMS_NORM: rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -632,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const char * op_type = NULL; switch (octx->op) { + case HTP_OP_NORM: + op_type = "norm-f32"; + break; case HTP_OP_RMS_NORM: op_type = "rmsnorm-f32"; break;