Feature hexagon l2 norm (#22816)

* L2_NORM Updates

* Addressed PR Comments

* ggml-hexagon: add L2_NORM HVX kernel for Hexagon backend

* hex-unary: remove supported_unary_nc since the outer loop is the same for all unary ops

---------

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
Pranav Dhinakar 2026-05-08 13:41:40 -07:00 committed by GitHub
parent 49956041ee
commit b46812de78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 91 additions and 2 deletions

View file

@ -2420,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
return false;
}
// TODO: add support for non-contiguous elements within a row
if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) {
// dst must be contiguous; src0 may be non-contiguous
if (!ggml_is_contiguous(dst)) {
return false;
}
@ -2791,6 +2791,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
case GGML_OP_SCALE: return HTP_OP_SCALE;
case GGML_OP_SQR: return HTP_OP_SQR;
@ -3253,6 +3254,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_add_id(sess, op);
break;
case GGML_OP_L2_NORM:
supp = ggml_hexagon_supported_unary(sess, op);
break;
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
supp = ggml_hexagon_supported_unary(sess, op);

View file

@ -83,6 +83,8 @@ enum htp_op_code {
HTP_OP_FILL,
HTP_OP_DIAG,
HTP_OP_SOLVE_TRI,
HTP_OP_L2_NORM,
HTP_OP_INVALID
};

View file

@ -542,6 +542,7 @@ static int execute_op(struct htp_ops_context * octx) {
case HTP_OP_UNARY_SIGMOID:
case HTP_OP_UNARY_NEG:
case HTP_OP_UNARY_EXP:
case HTP_OP_L2_NORM:
return op_unary(octx);
case HTP_OP_UNARY_SILU:

View file

@ -298,6 +298,81 @@ static void softplus_f32(const float * restrict src,
}
}
// --- L2_NORM HVX kernel ---
// Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row.
// scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers
// using rsqrt + inverse to avoid scalar extraction.
// Normalize one row of FP32 data in place of dst: y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon).
// Two passes over the row: (1) accumulate sum of squares in qf32, (2) scale each element.
// The scale factor is kept in an HVX vector register end-to-end (rsqrt + reciprocal),
// avoiding a round trip through a scalar register.
//
// src/dst: row base pointers; src must be vector (128-byte) aligned — indexed as HVX_Vector[].
//          NOTE(review): dst tail is stored via hvx_vec_store_a, which presumably requires
//          dst vector alignment too — confirm against caller's spad layout.
// pad:     unused scratch (kept for signature parity with the other unary kernels).
// num_elems: number of FP32 elements in the row.
// epsilon: lower bound on the denominator, guards against division by ~0 for all-zero rows.
static void hvx_fast_l2_norm_f32(const uint8_t * restrict src,
uint8_t * restrict dst,
uint8_t * restrict pad,
const int num_elems,
float epsilon) {
(void)pad;
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
// Running sum of squares, held in qf32 (wider-precision accumulation format).
HVX_Vector sum_v = hvx_vec_splat_f32(0.0f);
const int nvec = num_elems / VLEN_FP32;  // full vectors in the row
const int nloe = num_elems % VLEN_FP32;  // leftover (tail) elements
// Pass 1: accumulate x[i]^2 lane-wise over all full vectors.
#pragma unroll(4)
for (int i = 0; i < nvec; i++) {
HVX_Vector v1 = v_src[i];
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
}
// Include tail elements in the sum-of-squares using a predicate mask
// (lanes past the tail are zeroed so they contribute nothing to the sum).
// NOTE(review): this reads a full vector at v_src[nvec]; assumes the row buffer
// is padded to a vector multiple, as for the other unary kernels — confirm.
if (nloe > 0) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
}
// Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers.
// hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction.
// sqrt(sum) is derived as 1/rsqrt(sum) so that fmax with epsilon can be
// applied to the denominator itself (matching the reference formula).
HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf); // 1/sqrt(sum)
HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v); // sqrt(sum)
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v); // fmax(sqrt(sum), epsilon)
HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v); // 1/fmax(sqrt(sum), epsilon)
// Pass 2: y[i] = x[i] * scale for all full vectors.
#pragma unroll(4)
for (int i = 0; i < nvec; i++) {
HVX_Vector v1 = v_src[i];
v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
}
// Tail: compute masked product, then store only the nloe*4 valid bytes
// so bytes past the row end are left untouched.
if (nloe > 0) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
}
}
// L2-normalize a block of rows: each row is divided by fmax(sqrt(sum of squares), epsilon).
//
// src/dst:    base pointers of the first source/destination row.
// spad:       scratch pad, forwarded untouched to the per-row kernel.
// num_rows:   number of rows to process.
// row_elems:  FP32 elements per row.
// row_size:   byte stride between consecutive rows (same for src and dst).
// op_params:  op parameter blob; the first 4 bytes hold epsilon as a float.
static void l2_norm_f32(const float * restrict src,
                        float * restrict dst,
                        uint8_t * restrict spad,
                        const uint32_t num_rows,
                        const uint32_t row_elems,
                        const size_t row_size,
                        int32_t * op_params) {
    // epsilon is stored bit-exact in the int32 params; memcpy avoids aliasing issues.
    float eps = 0.f;
    memcpy(&eps, op_params, sizeof(float));

    // Walk both tensors with byte pointers advanced by the row stride.
    const uint8_t * src_row = (const uint8_t *) src;
    uint8_t *       dst_row = (uint8_t *) dst;

    for (uint32_t r = 0; r < num_rows; r++) {
        hvx_fast_l2_norm_f32(src_row, dst_row, spad, row_elems, eps);
        src_row += row_size;
        dst_row += row_size;
    }
}
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
struct htp_ops_context * octx = uctx->octx;
@ -402,6 +477,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
case HTP_OP_UNARY_SOFTPLUS:
softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_L2_NORM:
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
default:
break;
}
@ -469,6 +547,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
case HTP_OP_UNARY_SOFTPLUS:
op_type = "softplus-f32";
break;
case HTP_OP_L2_NORM:
op_type = "l2norm-f32";
break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);