ggml: fixed Arm SVE usage bug in vec.h, vec.cpp (#22841)

* Updated vec.h/vec.cpp code to accumulate to F32 rather than F16



Change-Id: I0cb789347f2bf60ffaf9047319f727e788c825f8

Signed-off-by: Martin Klacer <martin.klacer@arm.com>
Co-authored-by: Milos Puzovic <Milos.Puzovic@arm.com>
This commit is contained in:
Martin Klacer 2026-05-28 08:04:21 +01:00 committed by GitHub
parent 491c4d7d2e
commit e31cdaa0eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 105 additions and 139 deletions

View file

@ -273,67 +273,51 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8; //get vector length
const int ggml_f16_epr = sve_register_length / 16; // running when 16
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
const int ggml_f16_epr = svcnth();
const int ggml_f16_step = 8 * ggml_f16_epr;
const int np = n - (n % ggml_f16_step);
const int np2 = n - (n % ggml_f16_epr);
const int np= (n & ~(ggml_f16_step - 1));
svfloat16_t sum1 = svdup_n_f16(0.0f);
svfloat16_t sum2 = svdup_n_f16(0.0f);
svfloat16_t sum3 = svdup_n_f16(0.0f);
svfloat16_t sum4 = svdup_n_f16(0.0f);
svfloat32_t sum1_lo = svdup_n_f32(0.0f);
svfloat32_t sum1_hi = svdup_n_f32(0.0f);
svfloat32_t sum2_lo = svdup_n_f32(0.0f);
svfloat32_t sum2_hi = svdup_n_f32(0.0f);
svfloat32_t sum3_lo = svdup_n_f32(0.0f);
svfloat32_t sum3_hi = svdup_n_f32(0.0f);
svfloat32_t sum4_lo = svdup_n_f32(0.0f);
svfloat32_t sum4_hi = svdup_n_f32(0.0f);
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0), GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0));
ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1), GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1));
ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2), GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2));
ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3), GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3));
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4), GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4));
ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5), GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5));
ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6), GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6));
ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7), GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7));
}
const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
for (int i = np; i < np2; i += ggml_f16_epr) {
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i, 0), GGML_F16x_VEC_LOAD(y + i, 0));
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
const svbool_t pg = svwhilelt_b16(np2, n);
const svfloat16_t rx = svld1_f16(pg, (const __fp16 *)(x + np2));
const svfloat16_t ry = svld1_f16(pg, (const __fp16 *)(y + np2));
sum1 = svmad_f16_x(pg, hx, hy, sum1);
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, rx, ry);
}
GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum2_lo);
sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum2_hi);
sum3_lo = svadd_f32_m(DEFAULT_PG32, sum3_lo, sum4_lo);
sum3_hi = svadd_f32_m(DEFAULT_PG32, sum3_hi, sum4_hi);
sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum3_lo);
sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum3_hi);
sumf = ggml_sve_sum_f32x2(sum1_lo, sum1_hi);
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
int vl = __riscv_vsetvlmax_e32m2();

View file

@ -14,6 +14,35 @@
// floating point type used to accumulate sums
typedef double ggml_float;
#if defined(__ARM_FEATURE_SVE)
inline static void ggml_sve_f16_fma_widened(
svfloat32_t * acc_lo,
svfloat32_t * acc_hi,
svfloat16_t x,
svfloat16_t y) {
#if defined(__ARM_FEATURE_SVE2)
*acc_lo = svmlalb_f32(*acc_lo, x, y);
*acc_hi = svmlalt_f32(*acc_hi, x, y);
#else
// Plain SVE fallback path if SVE2 instructions not available
svfloat16_t x_even = svtrn1_f16(x, x);
svfloat16_t x_odd = svtrn2_f16(x, x);
svfloat16_t y_even = svtrn1_f16(y, y);
svfloat16_t y_odd = svtrn2_f16(y, y);
svbool_t pg = svptrue_b32();
*acc_lo = svmla_f32_x(pg, *acc_lo, svcvt_f32_f16_x(pg, x_even), svcvt_f32_f16_x(pg, y_even));
*acc_hi = svmla_f32_x(pg, *acc_hi, svcvt_f32_f16_x(pg, x_odd), svcvt_f32_f16_x(pg, y_odd));
#endif
}
inline static ggml_float ggml_sve_sum_f32x2(svfloat32_t sum_lo, svfloat32_t sum_hi) {
return (ggml_float) (svaddv_f32(svptrue_b32(), sum_lo) + svaddv_f32(svptrue_b32(), sum_hi));
}
#endif
#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16
@ -122,108 +151,61 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16; // running when 16
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
const int ggml_f16_epr = svcnth();
const int ggml_f16_step = 2 * ggml_f16_epr;
int np = n - (n % ggml_f16_step);
int np2 = n - (n % ggml_f16_epr);
int np = (n & ~(ggml_f16_step - 1));
svfloat16_t sum_00 = svdup_n_f16(0.0f);
svfloat16_t sum_01 = svdup_n_f16(0.0f);
svfloat16_t sum_02 = svdup_n_f16(0.0f);
svfloat16_t sum_03 = svdup_n_f16(0.0f);
svfloat16_t sum_10 = svdup_n_f16(0.0f);
svfloat16_t sum_11 = svdup_n_f16(0.0f);
svfloat16_t sum_12 = svdup_n_f16(0.0f);
svfloat16_t sum_13 = svdup_n_f16(0.0f);
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
svfloat32_t sum_0_0_lo = svdup_n_f32(0.0f);
svfloat32_t sum_0_0_hi = svdup_n_f32(0.0f);
svfloat32_t sum_0_1_lo = svdup_n_f32(0.0f);
svfloat32_t sum_0_1_hi = svdup_n_f32(0.0f);
svfloat32_t sum_1_0_lo = svdup_n_f32(0.0f);
svfloat32_t sum_1_0_hi = svdup_n_f32(0.0f);
svfloat32_t sum_1_1_lo = svdup_n_f32(0.0f);
svfloat32_t sum_1_1_hi = svdup_n_f32(0.0f);
for (int i = 0; i < np; i += ggml_f16_step) {
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
const svfloat16_t ay0 = GGML_F16x_VEC_LOAD(y + i, 0);
const svfloat16_t ax00 = GGML_F16x_VEC_LOAD(x[0] + i, 0);
const svfloat16_t ax01 = GGML_F16x_VEC_LOAD(x[1] + i, 0);
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax00, ay0);
ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax01, ay0);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
const svfloat16_t ay1 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 0);
const svfloat16_t ax10 = GGML_F16x_VEC_LOAD(x[0] + i + 1 * ggml_f16_epr, 0);
const svfloat16_t ax11 = GGML_F16x_VEC_LOAD(x[1] + i + 1 * ggml_f16_epr, 0);
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3);
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4);
ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3);
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4);
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5);
ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4);
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5);
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6);
ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6);
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7);
ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7);
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8);
ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7);
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8);
ggml_sve_f16_fma_widened(&sum_0_1_lo, &sum_0_1_hi, ax10, ay1);
ggml_sve_f16_fma_widened(&sum_1_1_lo, &sum_1_1_hi, ax11, ay1);
}
const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
for (int i = np; i < np2; i += ggml_f16_epr) {
const svfloat16_t ry = GGML_F16x_VEC_LOAD(y + i, 0);
const svfloat16_t rx0 = GGML_F16x_VEC_LOAD(x[0] + i, 0);
const svfloat16_t rx1 = GGML_F16x_VEC_LOAD(x[1] + i, 0);
svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0);
sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry);
rx = GGML_F16x_VEC_LOAD(x[1] + k, 0);
sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry);
ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, rx0, ry);
ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, rx1, ry);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
const svbool_t pg = svwhilelt_b16(np2, n);
const svfloat16_t ay = svld1_f16(pg, (const __fp16 *)(y + np2));
const svfloat16_t ax0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
const svfloat16_t ax1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);
sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);
ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax0, ay);
ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax1, ay);
}
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
svfloat32_t sum_0_lo = svadd_f32_x(DEFAULT_PG32, sum_0_0_lo, sum_0_1_lo);
svfloat32_t sum_0_hi = svadd_f32_x(DEFAULT_PG32, sum_0_0_hi, sum_0_1_hi);
svfloat32_t sum_1_lo = svadd_f32_x(DEFAULT_PG32, sum_1_0_lo, sum_1_1_lo);
svfloat32_t sum_1_hi = svadd_f32_x(DEFAULT_PG32, sum_1_0_hi, sum_1_1_hi);
sumf[0] = ggml_sve_sum_f32x2(sum_0_lo, sum_0_hi);
sumf[1] = ggml_sve_sum_f32x2(sum_1_lo, sum_1_hi);
np = n;
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)