diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index d0e400133..67b6b05ca 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -273,67 +273,51 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; //get vector length - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 8 * ggml_f16_epr; + const int np = n - (n % ggml_f16_step); + const int np2 = n - (n % ggml_f16_epr); - const int np= (n & ~(ggml_f16_step - 1)); - svfloat16_t sum1 = svdup_n_f16(0.0f); - svfloat16_t sum2 = svdup_n_f16(0.0f); - svfloat16_t sum3 = svdup_n_f16(0.0f); - svfloat16_t sum4 = svdup_n_f16(0.0f); + svfloat32_t sum1_lo = svdup_n_f32(0.0f); + svfloat32_t sum1_hi = svdup_n_f32(0.0f); + svfloat32_t sum2_lo = svdup_n_f32(0.0f); + svfloat32_t sum2_hi = svdup_n_f32(0.0f); + svfloat32_t sum3_lo = svdup_n_f32(0.0f); + svfloat32_t sum3_hi = svdup_n_f32(0.0f); + svfloat32_t sum4_lo = svdup_n_f32(0.0f); + svfloat32_t sum4_hi = svdup_n_f32(0.0f); - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; for (int i = 0; i < np; i += ggml_f16_step) { - ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1); - - ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); - sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2); - - ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3); - - ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4); - - ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5); - - ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6); - - ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7); - - ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0), GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1), GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2), GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3), GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3)); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4), GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5), GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6), GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7), GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7)); } - const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8 - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry); + for (int i = np; i < np2; i += ggml_f16_epr) { + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i, 0), GGML_F16x_VEC_LOAD(y + i, 0)); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t rx = svld1_f16(pg, (const __fp16 *)(x + np2)); + const svfloat16_t ry = svld1_f16(pg, (const __fp16 *)(y + np2)); - sum1 = svmad_f16_x(pg, hx, hy, sum1); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, rx, ry); } - GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4); + + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum2_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum2_hi); + sum3_lo = svadd_f32_m(DEFAULT_PG32, sum3_lo, sum4_lo); + sum3_hi = svadd_f32_m(DEFAULT_PG32, sum3_hi, sum4_hi); + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum3_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum3_hi); + + sumf = ggml_sve_sum_f32x2(sum1_lo, sum1_hi); #elif defined(__riscv_v_intrinsic) #if defined(__riscv_zvfh) int vl = __riscv_vsetvlmax_e32m2(); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index bcd68da9a..5de9cb5b7 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -14,6 +14,35 @@ // floating point type used to accumulate sums typedef double ggml_float; +#if defined(__ARM_FEATURE_SVE) +inline static void ggml_sve_f16_fma_widened( + svfloat32_t * acc_lo, + svfloat32_t * acc_hi, + svfloat16_t x, + svfloat16_t y) { +#if defined(__ARM_FEATURE_SVE2) + *acc_lo = svmlalb_f32(*acc_lo, x, y); + *acc_hi = svmlalt_f32(*acc_hi, x, y); +#else + // Plain SVE fallback path if SVE2 instructions not available + svfloat16_t x_even = svtrn1_f16(x, x); + svfloat16_t x_odd = svtrn2_f16(x, x); + + svfloat16_t y_even = svtrn1_f16(y, y); + svfloat16_t y_odd = svtrn2_f16(y, y); + + svbool_t pg = svptrue_b32(); + + *acc_lo = svmla_f32_x(pg, *acc_lo, svcvt_f32_f16_x(pg, x_even), svcvt_f32_f16_x(pg, y_even)); + *acc_hi = svmla_f32_x(pg, *acc_hi, svcvt_f32_f16_x(pg, x_odd), svcvt_f32_f16_x(pg, y_odd)); +#endif +} + +inline static ggml_float ggml_sve_sum_f32x2(svfloat32_t sum_lo, svfloat32_t sum_hi) { + return (ggml_float) (svaddv_f32(svptrue_b32(), sum_lo) + svaddv_f32(svptrue_b32(), sum_hi)); +} +#endif + #define GGML_GELU_FP16 #define GGML_GELU_QUICK_FP16 @@ -122,108 +151,61 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 2 * ggml_f16_epr; + int np = n - (n % ggml_f16_step); + int np2 = n - (n % ggml_f16_epr); - int np = (n & ~(ggml_f16_step - 1)); - - svfloat16_t sum_00 = svdup_n_f16(0.0f); - svfloat16_t sum_01 = svdup_n_f16(0.0f); - svfloat16_t sum_02 = svdup_n_f16(0.0f); - svfloat16_t sum_03 = svdup_n_f16(0.0f); - - svfloat16_t sum_10 = svdup_n_f16(0.0f); - svfloat16_t sum_11 = svdup_n_f16(0.0f); - svfloat16_t sum_12 = svdup_n_f16(0.0f); - svfloat16_t sum_13 = svdup_n_f16(0.0f); - - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + svfloat32_t sum_0_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_hi = svdup_n_f32(0.0f); for (int i = 0; i < np; i += ggml_f16_step) { - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements + const svfloat16_t ay0 = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t ax00 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t ax01 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1 - ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax00, ay0); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax01, ay0); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements + const svfloat16_t ay1 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax10 = GGML_F16x_VEC_LOAD(x[0] + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax11 = GGML_F16x_VEC_LOAD(x[1] + i + 1 * ggml_f16_epr, 0); - ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2); - ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2); - - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - - ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2); - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3); - ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3); - - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - - ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3); - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4); - ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4); - - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - - ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4); - - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5); - ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4); - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5); - - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - - ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5); - - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6); - ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6); - - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - - ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6); - - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7); - ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7); - - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - - ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7); - - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8); - ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8); + ggml_sve_f16_fma_widened(&sum_0_1_lo, &sum_0_1_hi, ax10, ay1); + ggml_sve_f16_fma_widened(&sum_1_1_lo, &sum_1_1_hi, ax11, ay1); } - const int np2 = (n & ~(ggml_f16_epr - 1)); - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + for (int i = np; i < np2; i += ggml_f16_epr) { + const svfloat16_t ry = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t rx0 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t rx1 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0); - sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry); - rx = GGML_F16x_VEC_LOAD(x[1] + k, 0); - sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, rx0, ry); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, rx1, ry); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); - svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t ay = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svfloat16_t ax0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); + const svfloat16_t ax1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00); - sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax0, ay); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax1, ay); } - GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); - GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); + + svfloat32_t sum_0_lo = svadd_f32_x(DEFAULT_PG32, sum_0_0_lo, sum_0_1_lo); + svfloat32_t sum_0_hi = svadd_f32_x(DEFAULT_PG32, sum_0_0_hi, sum_0_1_hi); + svfloat32_t sum_1_lo = svadd_f32_x(DEFAULT_PG32, sum_1_0_lo, sum_1_1_lo); + svfloat32_t sum_1_hi = svadd_f32_x(DEFAULT_PG32, sum_1_0_hi, sum_1_1_hi); + sumf[0] = ggml_sve_sum_f32x2(sum_0_lo, sum_0_hi); + sumf[1] = ggml_sve_sum_f32x2(sum_1_lo, sum_1_hi); np = n; #elif defined(__riscv_v_intrinsic) #if defined(__riscv_zvfh)