mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-17 04:09:19 +00:00
Feature hexagon l2 norm (#22816)
* L2_NORM Updates * Addressed PR Comments * ggml-hexagon: add L2_NORM HVX kernel for Hexagon backend * hex-unary: remove supported_unary_nc since the outer loop is the same for all unary ops --------- Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
parent
49956041ee
commit
b46812de78
4 changed files with 91 additions and 2 deletions
|
|
@ -2420,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
|||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contiguous elements within a row
|
||||
if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) {
|
||||
// dst must be contiguous; src0 may be non-contiguous
|
||||
if (!ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2791,6 +2791,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
|
|||
case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
|
||||
case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
|
||||
case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
|
||||
case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
|
||||
case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
|
||||
case GGML_OP_SCALE: return HTP_OP_SCALE;
|
||||
case GGML_OP_SQR: return HTP_OP_SQR;
|
||||
|
|
@ -3253,6 +3254,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
supp = ggml_hexagon_supported_add_id(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_L2_NORM:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
|
|
|
|||
|
|
@ -83,6 +83,8 @@ enum htp_op_code {
|
|||
HTP_OP_FILL,
|
||||
HTP_OP_DIAG,
|
||||
HTP_OP_SOLVE_TRI,
|
||||
HTP_OP_L2_NORM,
|
||||
|
||||
HTP_OP_INVALID
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -542,6 +542,7 @@ static int execute_op(struct htp_ops_context * octx) {
|
|||
case HTP_OP_UNARY_SIGMOID:
|
||||
case HTP_OP_UNARY_NEG:
|
||||
case HTP_OP_UNARY_EXP:
|
||||
case HTP_OP_L2_NORM:
|
||||
return op_unary(octx);
|
||||
|
||||
case HTP_OP_UNARY_SILU:
|
||||
|
|
|
|||
|
|
@ -298,6 +298,81 @@ static void softplus_f32(const float * restrict src,
|
|||
}
|
||||
}
|
||||
|
||||
// --- L2_NORM HVX kernel ---
|
||||
// Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row.
|
||||
// scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers
|
||||
// using rsqrt + inverse to avoid scalar extraction.
|
||||
static void hvx_fast_l2_norm_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict pad,
|
||||
const int num_elems,
|
||||
float epsilon) {
|
||||
(void)pad;
|
||||
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
HVX_Vector sum_v = hvx_vec_splat_f32(0.0f);
|
||||
|
||||
const int nvec = num_elems / VLEN_FP32;
|
||||
const int nloe = num_elems % VLEN_FP32;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
|
||||
}
|
||||
|
||||
// Include tail elements in the sum-of-squares using a predicate mask
|
||||
if (nloe > 0) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
|
||||
}
|
||||
|
||||
// Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers.
|
||||
// hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction.
|
||||
HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
||||
HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf); // 1/sqrt(sum)
|
||||
HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v); // sqrt(sum)
|
||||
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
||||
HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v); // fmax(sqrt(sum), epsilon)
|
||||
HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v); // 1/fmax(sqrt(sum), epsilon)
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
|
||||
}
|
||||
|
||||
if (nloe > 0) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
|
||||
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the L2 norm kernel over num_rows rows.
//
//   src, dst  - base pointers of the input / output rows; row r of each
//               starts r * row_size bytes from its base pointer
//   spad      - forwarded unchanged to hvx_fast_l2_norm_f32
//   num_rows  - number of rows to process
//   row_elems - number of f32 elements per row
//   row_size  - distance between consecutive rows, in bytes
//   op_params - the first sizeof(float) bytes are copied out as epsilon
static void l2_norm_f32(const float * restrict src,
                        float * restrict       dst,
                        uint8_t * restrict     spad,
                        const uint32_t         num_rows,
                        const uint32_t         row_elems,
                        const size_t           row_size,
                        int32_t *              op_params) {
    // epsilon is stored as raw float bits inside the int32 parameter array;
    // memcpy copies them out without a type-punning pointer cast.
    float epsilon = 0.f;
    memcpy(&epsilon, op_params, sizeof(float));

    const uint8_t * src_base = (const uint8_t *) src;
    uint8_t *       dst_base = (uint8_t *) dst;

    for (uint32_t row = 0; row < num_rows; row++) {
        const size_t byte_off = row * row_size;

        hvx_fast_l2_norm_f32(src_base + byte_off, dst_base + byte_off, spad, row_elems, epsilon);
    }
}
|
||||
|
||||
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
||||
struct htp_ops_context * octx = uctx->octx;
|
||||
|
|
@ -402,6 +477,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
|||
case HTP_OP_UNARY_SOFTPLUS:
|
||||
softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_L2_NORM:
|
||||
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -469,6 +547,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
case HTP_OP_UNARY_SOFTPLUS:
|
||||
op_type = "softplus-f32";
|
||||
break;
|
||||
case HTP_OP_L2_NORM:
|
||||
op_type = "l2norm-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue