hexagon: enable support for NORM op (#23319)

This commit is contained in:
Aparna M P 2026-05-19 22:18:21 +05:30 committed by GitHub
parent baf3cc6e1d
commit ac76808e4d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 101 additions and 3 deletions

View file

@ -2870,6 +2870,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
case GGML_OP_NORM: return HTP_OP_NORM;
case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
case GGML_OP_SCALE: return HTP_OP_SCALE;
@ -3338,10 +3339,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_add_id(sess, op);
break;
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
supp = ggml_hexagon_supported_unary(sess, op);
break;
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
supp = ggml_hexagon_supported_unary(sess, op);

View file

@ -88,6 +88,7 @@ enum htp_op_code {
HTP_OP_GATED_DELTA_NET,
HTP_OP_TRI,
HTP_OP_PAD,
HTP_OP_NORM,
HTP_OP_INVALID
};

View file

@ -534,6 +534,7 @@ static int execute_op(struct htp_ops_context * octx) {
case HTP_OP_ADD_ID:
return op_binary(octx);
case HTP_OP_NORM:
case HTP_OP_RMS_NORM:
case HTP_OP_SCALE:
case HTP_OP_SQR:

View file

@ -158,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
}
}
static void hvx_fast_norm_f32(const uint8_t * restrict src,
uint8_t * restrict dst,
uint8_t * restrict pad,
const int num_elems,
float epsilon) {
(void)pad;
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
const int nvec = num_elems / VLEN_FP32; // number of full vectors
const int nloe = num_elems % VLEN_FP32; // leftover elements
// Compute sum of squares and sum of values for full vectors
HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000);
HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000);
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
#pragma unroll(4)
for (int i = 0; i < nvec; i++) {
HVX_Vector v1 = v_src[i];
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
}
// Handle tail elements using vectorized ops with masking
if (nloe > 0) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
}
// Reduce HVX sums
sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v));
sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v));
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v);
HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v);
HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v));
HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v);
HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v);
// scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v));
HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v)));
#pragma unroll(4)
for (int i = 0; i < nvec; i++) {
HVX_Vector v1 = v_src[i];
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
v_dst[i] = Q6_Vsf_equals_Vqf32(v3);
}
// Handle tail elements using vectorized ops with masking
if (nloe > 0) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
HVX_Vector result = Q6_Vsf_equals_Vqf32(v3);
// Store with masking to avoid overwriting memory beyond the tensor
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
}
}
static void scale_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
@ -196,6 +269,24 @@ static void rms_norm_f32(const float * restrict src,
}
}
static void norm_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params) {
float epsilon = 0.f;
memcpy(&epsilon, op_params, sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
}
}
static void sqr_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
@ -556,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
// Process block in VTCM
switch (htp_op) {
case HTP_OP_NORM:
norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_RMS_NORM:
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
@ -632,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
const char * op_type = NULL;
switch (octx->op) {
case HTP_OP_NORM:
op_type = "norm-f32";
break;
case HTP_OP_RMS_NORM:
op_type = "rmsnorm-f32";
break;