mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-22 11:16:08 +00:00
hexagon: enable support for NORM op (#23319)
This commit is contained in:
parent
baf3cc6e1d
commit
ac76808e4d
4 changed files with 101 additions and 3 deletions
|
|
@ -2870,6 +2870,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
|
|||
case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
|
||||
case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
|
||||
case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
|
||||
case GGML_OP_NORM: return HTP_OP_NORM;
|
||||
case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
|
||||
case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
|
||||
case GGML_OP_SCALE: return HTP_OP_SCALE;
|
||||
|
|
@ -3338,10 +3339,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
supp = ggml_hexagon_supported_add_id(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
|
|
|
|||
|
|
@ -88,6 +88,7 @@ enum htp_op_code {
|
|||
HTP_OP_GATED_DELTA_NET,
|
||||
HTP_OP_TRI,
|
||||
HTP_OP_PAD,
|
||||
HTP_OP_NORM,
|
||||
|
||||
HTP_OP_INVALID
|
||||
};
|
||||
|
|
|
|||
|
|
@ -534,6 +534,7 @@ static int execute_op(struct htp_ops_context * octx) {
|
|||
case HTP_OP_ADD_ID:
|
||||
return op_binary(octx);
|
||||
|
||||
case HTP_OP_NORM:
|
||||
case HTP_OP_RMS_NORM:
|
||||
case HTP_OP_SCALE:
|
||||
case HTP_OP_SQR:
|
||||
|
|
|
|||
|
|
@ -158,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|||
}
|
||||
}
|
||||
|
||||
static void hvx_fast_norm_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict pad,
|
||||
const int num_elems,
|
||||
float epsilon) {
|
||||
(void)pad;
|
||||
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
||||
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
||||
|
||||
// Compute sum of squares and sum of values for full vectors
|
||||
HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
|
||||
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
|
||||
}
|
||||
|
||||
// Handle tail elements using vectorized ops with masking
|
||||
if (nloe > 0) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
|
||||
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
|
||||
}
|
||||
|
||||
// Reduce HVX sums
|
||||
sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v));
|
||||
sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v));
|
||||
|
||||
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
||||
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
||||
HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v);
|
||||
HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v);
|
||||
HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v));
|
||||
HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v);
|
||||
HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v);
|
||||
|
||||
// scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction
|
||||
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v));
|
||||
HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v)));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
|
||||
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
|
||||
v_dst[i] = Q6_Vsf_equals_Vqf32(v3);
|
||||
}
|
||||
|
||||
// Handle tail elements using vectorized ops with masking
|
||||
if (nloe > 0) {
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
|
||||
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
|
||||
HVX_Vector result = Q6_Vsf_equals_Vqf32(v3);
|
||||
|
||||
// Store with masking to avoid overwriting memory beyond the tensor
|
||||
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
||||
}
|
||||
}
|
||||
|
||||
static void scale_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
|
|
@ -196,6 +269,24 @@ static void rms_norm_f32(const float * restrict src,
|
|||
}
|
||||
}
|
||||
|
||||
static void norm_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
float epsilon = 0.f;
|
||||
memcpy(&epsilon, op_params, sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
static void sqr_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
|
|
@ -556,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
|||
|
||||
// Process block in VTCM
|
||||
switch (htp_op) {
|
||||
case HTP_OP_NORM:
|
||||
norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
|
|
@ -632,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_NORM:
|
||||
op_type = "norm-f32";
|
||||
break;
|
||||
case HTP_OP_RMS_NORM:
|
||||
op_type = "rmsnorm-f32";
|
||||
break;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue