From 609ea50026a336a6cf3c02e596792477530b5928 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Sat, 14 Mar 2026 11:09:08 -0700 Subject: [PATCH 1/7] hexagon: Q4_0 and MXFP4 repack fixes (#20527) * hexagon: fix tail corruption with row sizes not a multiple of 256 * hexagon: use a different stride for repacking partial blocks * hex-mm: update repack and kernels to avoid shuffles for full 256-element blocks The previous commit changed the repacking to use even:odd (0:1,2:3,..) packing instead of the original (0:128,1:129,...) packing in order to fix tail corruption. Since the mm kernels already deal with partial tails we can use even:odd packing only for the last block. This avoids the performance penalty of having to shuffle to zip the elements in the common case. * hex-mm: update rmpy x8 for better optimizations * hex-mm: tighten supported MUL_MAT checks to avoid spurious failures * hex-mm: use vzero to init accumulators * hex-mm: properly call partial rmpy_x8 --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 43 ++- ggml/src/ggml-hexagon/htp/matmul-ops.c | 407 +++++++++++++++---------- 2 files changed, 287 insertions(+), 163 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index d6e9776b8..19917cb11 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -402,6 +402,7 @@ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { static const int qk = QK_Q4_0x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int dblk_size = 8 * 2; // 8x __fp16 const int qblk_size = qk / 2; // int4 @@ -435,9 +436,11 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { unpack_q4_0_quants(qs, &x[i * 8 + 6], 6); unpack_q4_0_quants(qs, &x[i * 8 + 7], 7); + bool partial = (nloe && i == nb-1); + uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - q[j] = (qs[j + 128] << 4) | qs[j]; + q[j] = partial ? 
(qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; } } @@ -467,6 +470,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { static const int qk = QK_Q4_0x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int dblk_size = 8 * 2; // 8x __fp16 const int qblk_size = qk / 2; // int4 @@ -485,10 +489,17 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { for (int i = 0; i < nb; i++) { uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + bool partial = (nloe && i == nb-1); + const uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - qs[j] = q[j] & 0xf; - qs[j + 128] = q[j] >> 4; + if (partial) { + qs[j*2+0] = q[j] & 0xf; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0xf; + qs[j+128] = q[j] >> 4; + } } pack_q4_0_quants(&x[i * 8 + 0], qs, 0); @@ -1078,6 +1089,7 @@ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) { static const int qk = QK_MXFP4x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int eblk_size = 8 * 1; // 8x E8M0 const int qblk_size = qk / 2; // int4 @@ -1112,9 +1124,11 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6); unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7); + bool partial = (nloe && i == nb-1); + uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - q[j] = (qs[j + 128] << 4) | qs[j]; + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; } } @@ -1144,6 +1158,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) { static const int qk = QK_MXFP4x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int eblk_size = 8 * 1; // 8x E8M0 const int qblk_size = qk / 2; // int4 @@ -1162,10 +1177,17 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) for (int i = 0; i < nb; i++) { uint8_t qs[QK_MXFP4x4x2]; // unpacked quants + bool partial = (nloe && i == nb-1); + const uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - qs[j] = q[j] & 0xf; - qs[j + 128] = q[j] >> 4; + if (partial) { + qs[j*2+0] = q[j] & 0xf; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0xf; + qs[j+128] = q[j] >> 4; + } } pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); @@ -1801,12 +1823,12 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; } - if (src0->ne[1] > 16 * 1024) { + if (ggml_nrows(src0) > 16 * 1024) { return false; // typically the lm-head which would be too large for VTCM } - if ((src1->ne[2] != 1 || src1->ne[3] != 1)) { - return false; + if (ggml_nrows(src1) > 1024 || src1->ne[2] != 1 || src1->ne[3] != 1) { + return false; // no huge batches or broadcasting (for now) } // src0 (weights) must be repacked @@ -1820,6 +1842,9 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } + if (ggml_nrows(src1) > 1024) { + return false; // no huge batches (for now) + } break; default: diff --git 
a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 9ca74aedf..73aaba79e 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -77,7 +77,7 @@ static inline size_t q8x4x2_row_size(uint32_t ne) { return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); } -static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { +static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) @@ -88,9 +88,9 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 @@ -111,7 +111,41 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { return r; } -static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) { +static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8); + r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... 
+ r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8); + r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8); + } + + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) @@ -144,7 +178,41 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) return r; } -static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { +static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); + r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); + } + + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; HVX_Vector v0 = vptr[0]; // first 128 vals @@ -160,6 +228,10 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { return r; } +static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) { + return hvx_vec_load_q8x4x8_full(ptr); +} + // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). // Accumulate each block into a single int32 value. // Return a single HVX vector with 32x int32 accumulators. @@ -167,14 +239,14 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { // if() checks are optimized out at compile time -- make sure to pass N as a constexpr. 
static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - HVX_Vector r0 = Q6_V_vsplat_R(0); - HVX_Vector r1 = Q6_V_vsplat_R(0); - HVX_Vector r2 = Q6_V_vsplat_R(0); - HVX_Vector r3 = Q6_V_vsplat_R(0); - HVX_Vector r4 = Q6_V_vsplat_R(0); - HVX_Vector r5 = Q6_V_vsplat_R(0); - HVX_Vector r6 = Q6_V_vsplat_R(0); - HVX_Vector r7 = Q6_V_vsplat_R(0); + HVX_Vector r0 = Q6_V_vzero(); + HVX_Vector r1 = Q6_V_vzero(); + HVX_Vector r2 = Q6_V_vzero(); + HVX_Vector r3 = Q6_V_vzero(); + HVX_Vector r4 = Q6_V_vzero(); + HVX_Vector r5 = Q6_V_vzero(); + HVX_Vector r6 = Q6_V_vzero(); + HVX_Vector r7 = Q6_V_vzero(); HVX_VectorPair p3; HVX_VectorPair p2; @@ -213,15 +285,42 @@ } static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) { - return hvx_vec_rmpy_x8_n(x, y, 1024); + HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); + HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); + HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); + HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); + HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); + HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); + HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); + HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); + + HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4); + HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4); + HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4); + HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4); + + r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); + r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); + r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); + r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); + + p0 = Q6_W_vdeal_VVR(r1, r0, -4); + p1 = Q6_W_vdeal_VVR(r3, r2, -4); + + r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); + r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); + + p0 = Q6_W_vdeal_VVR(r1, r0, -4); + r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); + + return r0; } -// Handle most common cases of tensors not multiple of 1024. -static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); }; - if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); }; - if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); }; - return hvx_vec_rmpy_x8_n(x, y, 1024); +static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { + if (n >= 512) + return hvx_vec_rmpy_x8_full(x, y); + + return hvx_vec_rmpy_x8_n(x, y, 512); } static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { @@ -246,7 +345,7 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). 
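[For readers following the packing change described in the commit message: full 256-element blocks keep the split layout (low nibbles hold elements 0..127, high nibbles hold 128..255), and only the final partial block uses even:odd nibble pairing, which the partial loaders above unzip with a single vshuff. A scalar sketch of the two byte layouts, mirroring the repack loop -- reference only, not part of the patch:

    // Pack 256 4-bit quants qs[0..255] into 128 bytes.
    static void pack_block_ref(uint8_t q[128], const uint8_t qs[256], int partial) {
        for (int j = 0; j < 128; j++) {
            q[j] = partial ? (uint8_t) ((qs[2 * j + 1] << 4) | qs[2 * j]) // tail block: even:odd pairs
                           : (uint8_t) ((qs[j + 128] << 4) | qs[j]);     // full block: 0..127 | 128..255
        }
    }
]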
@@ -257,12 +356,12 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); @@ -272,19 +371,19 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); @@ -326,8 +425,8 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). 
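[On the Q6_Q_vsetq_R(nloe / 8) mask used in the leftover paths: each fp32 lane of the accumulator/scale vectors covers one 32-element quantization block, so nloe leftover elements occupy nloe/32 lanes, i.e. (nloe/32) * 4 = nloe/8 bytes. A scalar equivalent, with ia/dd standing in for the r0_ia/r0_dd lanes (rows are a multiple of 32 elements, so the division is exact):

    // Zero the per-block dot products and combined scales past the valid blocks.
    for (int lane = nloe / 32; lane < 32; lane++) {
        ia[lane] = 0.0f;
        dd[lane] = 0.0f;
    }
]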
@@ -338,14 +437,14 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -359,23 +458,23 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); @@ -423,10 +522,10 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elements @@ -434,12 +533,12 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * 
restrict s0, float * uint32_t i = 0; for (; i < nb; i++) { // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); @@ -448,8 +547,8 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -473,18 +572,18 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * // Process leftovers if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -545,7 +644,7 @@ static void 
vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). @@ -556,12 +655,12 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); @@ -571,19 +670,19 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); @@ -625,8 +724,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). 
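[As in the q4 kernels, each block's contribution to the row sum is the product of the two per-block fp16 scales (combined in fp32) and the int32 dot product. A scalar model of one lane of the qf32 pipeline above -- block_contrib is a hypothetical helper, for orientation only:

    // One 32-element block: combined scale times integer dot product.
    static inline float block_contrib(__fp16 dx, __fp16 dy, int32_t ia) {
        float dd = (float) dx * (float) dy; // one lane of Q6_Wqf32_vmpy_VhfVhf
        return dd * (float) ia;             // accumulated into r0_sum
    }
]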
@@ -637,14 +736,14 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -658,14 +757,14 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); @@ -674,7 +773,7 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); @@ -722,10 +821,10 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elements @@ -733,12 +832,12 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * uint32_t i = 0; for (; i < nb; i++) { // Load src1 columns (reused 
across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); @@ -747,8 +846,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -772,18 +871,18 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * // Process leftovers if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -792,7 +891,7 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c0_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); @@ -844,7 +943,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). @@ -855,8 +954,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); @@ -887,12 +986,12 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving @@ -954,8 +1053,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). 
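[The 0.5 factor mentioned in the e8m0-halving comment above compensates for the quant table: assuming kvalues_mxfp4_lut holds the FP4 (E2M1) magnitudes doubled to integers {0, 1, 2, 3, 4, 6, 8, 12} and their negatives, halving the decoded E8M0 scale restores the true values. A scalar sketch under that assumption:

    #include <math.h>

    // Assumed doubled-FP4 table; verify against kvalues_mxfp4_lut.
    static const int8_t kvalues_mxfp4_ref[16] = {
        0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
    };

    static inline float mxfp4_dequant_ref(uint8_t e8m0, uint8_t code) {
        float d = ldexpf(0.5f, (int) e8m0 - 127); // 0.5 * 2^(e - 127)
        return d * (float) kvalues_mxfp4_ref[code & 0xF];
    }
]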
@@ -966,9 +1065,9 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); @@ -1007,14 +1106,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); @@ -1087,10 +1186,10 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elements @@ -1098,12 +1197,12 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float uint32_t i = 0; for (; i < nb; i++) { // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); @@ -1157,15 +1256,15 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float // Process leftovers if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q 
+ i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); @@ -1234,7 +1333,7 @@ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum_p = Q6_W_vzero(); uint32_t i = 0; @@ -1264,8 +1363,8 @@ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; - HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); - HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum0_p = Q6_W_vzero(); + HVX_VectorPair rsum1_p = Q6_W_vzero(); uint32_t i = 0; @@ -1303,10 +1402,10 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res uint32_t nloe = n % VLEN_FP16; // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_VectorPair r0_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); - HVX_VectorPair r0_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); - HVX_VectorPair r1_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); - HVX_VectorPair r1_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r0_c0_sum_p = Q6_W_vzero(); + HVX_VectorPair r0_c1_sum_p = Q6_W_vzero(); + HVX_VectorPair r1_c0_sum_p = Q6_W_vzero(); + HVX_VectorPair r1_c1_sum_p = Q6_W_vzero(); uint32_t i = 0; @@ -1358,7 +1457,7 @@ static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vzero(); uint32_t i = 0; @@ -1388,9 +1487,9 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); + const HVX_Vector zero = Q6_V_vzero(); - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vzero(); uint32_t 
i = 0; @@ -1973,7 +2072,7 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric assert((unsigned long) y_q % 128 == 0); HVX_Vector * vx = (HVX_Vector *) x; - HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector zero = Q6_V_vzero(); // Use reduce max fp32 to find max(abs(e)) first HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); @@ -2034,7 +2133,7 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric HVX_Vector * vx = (HVX_Vector *) x; // Load and convert into QF32 - HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector zero = Q6_V_vzero(); HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements @@ -2077,7 +2176,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric HVX_Vector * vx = (HVX_Vector *) x; // Load and convert into QF32 - HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector zero = Q6_V_vzero(); HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements From 3a6f059909ed5dab8587df5df4120315053d57a4 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sat, 14 Mar 2026 19:27:52 +0000 Subject: [PATCH 2/7] ci : try to optimize some jobs (#20521) * force arm version to test * run on either x86 or arm if we can help it, this only works for runs without ccache * readd other jobs * remove ccache --- .github/workflows/build-cmake-pkg.yml | 4 ++-- .github/workflows/build-linux-cross.yml | 4 ++-- .github/workflows/build.yml | 32 +++++-------------------- .github/workflows/release.yml | 2 +- .github/workflows/server-webui.yml | 2 +- 5 files changed, 12 insertions(+), 32 deletions(-) diff --git a/.github/workflows/build-cmake-pkg.yml b/.github/workflows/build-cmake-pkg.yml index 259efa43c..84cf8ddf4 100644 --- a/.github/workflows/build-cmake-pkg.yml +++ b/.github/workflows/build-cmake-pkg.yml @@ -5,7 +5,7 @@ on: jobs: linux: - runs-on: ubuntu-24.04 + runs-on: ubuntu-slim steps: - uses: actions/checkout@v6 with: @@ -14,7 +14,7 @@ jobs: - name: Install dependencies run: | sudo apt update - sudo apt install -y build-essential tcl + sudo apt install -y build-essential tcl cmake - name: Build run: | diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 8b6ebaf4a..dbcc1ee2a 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -142,7 +142,7 @@ jobs: # cmake --build build --config Release -j $(nproc) debian-13-loongarch64-cpu-cross: - runs-on: ubuntu-24.04 + runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 steps: @@ -197,7 +197,7 @@ jobs: cmake --build build --config Release -j $(nproc) debian-13-loongarch64-vulkan-cross: - runs-on: ubuntu-24.04 + runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 steps: diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 102d90455..cfc78643b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -317,8 +317,8 @@ jobs: cd build ctest -L main --verbose --timeout 900 - ubuntu-latest-llguidance: - runs-on: 
ubuntu-latest + ubuntu-24-llguidance: + runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} steps: - name: Clone @@ -345,8 +345,8 @@ jobs: cd build ctest -L main --verbose --timeout 900 - ubuntu-latest-cmake-rpc: - runs-on: ubuntu-latest + ubuntu-24-cmake-rpc: + runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} continue-on-error: true @@ -355,12 +355,6 @@ jobs: id: checkout uses: actions/checkout@v6 - # - name: ccache - # uses: ggml-org/ccache-action@v1.2.16 - # with: - # key: ubuntu-latest-cmake-rpc - # evict-old-files: 1d - - name: Dependencies id: depends run: | @@ -381,20 +375,13 @@ jobs: ctest -L main --verbose ubuntu-24-cmake-vulkan-deb: - runs-on: ubuntu-24.04 + runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} steps: - name: Clone id: checkout uses: actions/checkout@v6 - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ubuntu-24-cmake-vulkan-deb - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - name: Dependencies id: depends run: | @@ -545,20 +532,13 @@ jobs: ctest -L main --verbose --timeout 900 ubuntu-24-wasm-webgpu: - runs-on: ubuntu-24.04 + runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} steps: - name: Clone id: checkout uses: actions/checkout@v6 - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ubuntu-latest-wasm-webgpu - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - name: Install Emscripten run: | git clone https://github.com/emscripten-core/emsdk.git diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6d5d3774d..1620d9a1b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -952,7 +952,7 @@ jobs: permissions: contents: write # for creating release - runs-on: ubuntu-latest + runs-on: ubuntu-slim needs: - windows diff --git a/.github/workflows/server-webui.yml b/.github/workflows/server-webui.yml index 94899c937..4d560ff58 100644 --- a/.github/workflows/server-webui.yml +++ b/.github/workflows/server-webui.yml @@ -29,7 +29,7 @@ concurrency: jobs: webui-check: name: WebUI Checks - runs-on: ubuntu-latest + runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} continue-on-error: true steps: - name: Checkout code From fc350fdf96d60474378e382a306d58c67633986c Mon Sep 17 00:00:00 2001 From: Gerard Guillemas Martos Date: Sat, 14 Mar 2026 21:37:09 +0100 Subject: [PATCH 3/7] docker : force Python 3.13 in Vulkan container (#20530) * ci: force Python 3.13 in Vulkan container * remove unnecessary `update-alternatives` line --- .devops/vulkan.Dockerfile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index 5d6c87ed6..3112ec85e 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -53,10 +53,11 @@ RUN apt-get update \ && apt-get install -y \ build-essential \ git \ - python3 \ - python3-dev \ + python3.13 \ + python3.13-dev \ python3-pip \ python3-wheel \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.13 100 \ && pip install --break-system-packages --upgrade setuptools \ && pip install --break-system-packages -r requirements.txt \ && apt autoremove -y \ From b4768955c47f8052710dbbe0cbf816f9e7aca93d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Mar 2026 23:15:35 +0200 Subject: [PATCH 4/7] ci : move self-hosted workflows to separate files (#20540) --- .github/workflows/build-cache.yml | 2 +- .github/workflows/build-self-hosted.yml | 
250 ++++++++++++++++++ .github/workflows/build.yml | 196 +------------- .github/workflows/release.yml | 2 +- ...erver-metal.yml => server-self-hosted.yml} | 45 +++- 5 files changed, 295 insertions(+), 200 deletions(-) create mode 100644 .github/workflows/build-self-hosted.yml rename .github/workflows/{server-metal.yml => server-self-hosted.yml} (56%) diff --git a/.github/workflows/build-cache.yml b/.github/workflows/build-cache.yml index 7cfdaff60..dffbf2b4a 100644 --- a/.github/workflows/build-cache.yml +++ b/.github/workflows/build-cache.yml @@ -67,7 +67,7 @@ jobs: runs-on: ubuntu-24.04 env: - # Sync versions in build.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile + # Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile OPENVINO_VERSION_MAJOR: "2026.0" OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886" diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml new file mode 100644 index 000000000..eba06b96b --- /dev/null +++ b/.github/workflows/build-self-hosted.yml @@ -0,0 +1,250 @@ +name: CI (self-hosted) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.metal', + '**/*.comp', + '**/*.glsl', + '**/*.wgsl' + ] + + pull_request: + types: [opened, synchronize, reopened] + paths: [ + '.github/workflows/build-self-hosted.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.metal', + '**/*.comp', + '**/*.glsl', + '**/*.wgsl' + ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + GGML_NLOOP: 3 + GGML_N_THREADS: 1 + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + +jobs: + ggml-ci-nvidia-cuda: + runs-on: [self-hosted, Linux, NVIDIA] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + nvidia-smi + GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + + ggml-ci-nvidia-vulkan-cm: + runs-on: [self-hosted, Linux, NVIDIA] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + + ggml-ci-nvidia-vulkan-cm2: + runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + + ggml-ci-cpu-amx: + runs-on: [self-hosted, Linux, CPU, AMX] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + + # ggml-ci-amd-vulkan: + # runs-on: [self-hosted, Linux, AMD] + + # steps: + # - name: Clone + # id: checkout + # uses: actions/checkout@v6 + + # - name: Test + # id: ggml-ci + # run: | + # vulkaninfo --summary + # GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + + # ggml-ci-amd-rocm: + # runs-on: [self-hosted, Linux, AMD] + + # steps: + # - name: Clone + # id: checkout + # uses: actions/checkout@v6 + + # 
- name: Test + # id: ggml-ci + # run: | + # amd-smi static + # GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + + ggml-ci-mac-metal: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + + ggml-ci-mac-webgpu: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Dawn Dependency + id: dawn-depends + run: | + DAWN_VERSION="v2.0.0" + DAWN_OWNER="reeselevine" + DAWN_REPO="dawn" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release" + echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" + curl -L -o artifact.zip \ + "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" + mkdir dawn + unzip artifact.zip + tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1 + + - name: Test + id: ggml-ci + run: | + GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \ + bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + + ggml-ci-mac-vulkan: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + + ggml-ci-linux-intel-vulkan: + runs-on: [self-hosted, Linux, Intel] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + with: + persist-credentials: false + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + + ggml-ci-intel-openvino-gpu-low-perf: + runs-on: [self-hosted, Linux, Intel, OpenVINO] + + env: + # Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile + OPENVINO_VERSION_MAJOR: "2026.0" + OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Use OpenVINO Toolkit Cache + uses: actions/cache@v5 + id: cache-openvino + with: + path: ./openvino_toolkit + key: openvino-toolkit-v${{ env.OPENVINO_VERSION_FULL }}-${{ runner.os }} + + - name: Setup OpenVINO Toolkit + if: steps.cache-openvino.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-openvino + with: + path: ./openvino_toolkit + version_major: ${{ env.OPENVINO_VERSION_MAJOR }} + version_full: ${{ env.OPENVINO_VERSION_FULL }} + + - name: Install OpenVINO dependencies + run: | + cd ./openvino_toolkit + chmod +x ./install_dependencies/install_openvino_dependencies.sh + echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh + + - name: Test + id: ggml-ci + run: | + source ./openvino_toolkit/setupvars.sh + GG_BUILD_OPENVINO=1 GGML_OPENVINO_DEVICE=GPU GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cfc78643b..460a77012 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -739,7 +739,7 @@ jobs: runs-on: ${{ fromJSON(matrix.runner) }} env: - # Sync versions in build.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile + # Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, 
.devops/openvino.Dockerfile OPENVINO_VERSION_MAJOR: "2026.0" OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886" @@ -1646,160 +1646,6 @@ jobs: run: | LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - ggml-ci-x64-nvidia-cuda: - runs-on: [self-hosted, Linux, X64, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - nvidia-smi - GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - - ggml-ci-x64-nvidia-vulkan-cm: - runs-on: [self-hosted, Linux, X64, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - - ggml-ci-x64-nvidia-vulkan-cm2: - runs-on: [self-hosted, Linux, X64, NVIDIA, COOPMAT2] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - - ggml-ci-x64-cpu-amx: - runs-on: [self-hosted, Linux, X64, CPU, AMX] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - - # ggml-ci-x64-amd-vulkan: - # runs-on: [self-hosted, Linux, X64, AMD] - - # steps: - # - name: Clone - # id: checkout - # uses: actions/checkout@v6 - - # - name: Test - # id: ggml-ci - # run: | - # vulkaninfo --summary - # GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - - # ggml-ci-x64-amd-rocm: - # runs-on: [self-hosted, Linux, X64, AMD] - - # steps: - # - name: Clone - # id: checkout - # uses: actions/checkout@v6 - - # - name: Test - # id: ggml-ci - # run: | - # amd-smi static - # GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - - ggml-ci-mac-metal: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp - - ggml-ci-mac-webgpu: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Dawn Dependency - id: dawn-depends - run: | - DAWN_VERSION="v2.0.0" - DAWN_OWNER="reeselevine" - DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release" - echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" - curl -L -o artifact.zip \ - "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" - mkdir dawn - unzip artifact.zip - tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1 - - - name: Test - id: ggml-ci - run: | - GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \ - bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp - - ggml-ci-mac-vulkan: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp - - ggml-ci-x64-linux-intel-vulkan: - runs-on: [self-hosted, Linux, X64, Intel] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - with: - 
persist-credentials: false - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp - ggml-ci-arm64-cpu-kleidiai: runs-on: ubuntu-22.04-arm @@ -1826,46 +1672,6 @@ jobs: run: | GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - ggml-ci-x64-intel-openvino-gpu-low-perf: - runs-on: [self-hosted, Linux, X64, Intel, OpenVINO] - - env: - # Sync versions in build.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile - OPENVINO_VERSION_MAJOR: "2026.0" - OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886" - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Use OpenVINO Toolkit Cache - uses: actions/cache@v5 - id: cache-openvino - with: - path: ./openvino_toolkit - key: openvino-toolkit-v${{ env.OPENVINO_VERSION_FULL }}-${{ runner.os }} - - - name: Setup OpenVINO Toolkit - if: steps.cache-openvino.outputs.cache-hit != 'true' - uses: ./.github/actions/linux-setup-openvino - with: - path: ./openvino_toolkit - version_major: ${{ env.OPENVINO_VERSION_MAJOR }} - version_full: ${{ env.OPENVINO_VERSION_FULL }} - - - name: Install OpenVINO dependencies - run: | - cd ./openvino_toolkit - chmod +x ./install_dependencies/install_openvino_dependencies.sh - echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh - - - name: Test - id: ggml-ci - run: | - source ./openvino_toolkit/setupvars.sh - GG_BUILD_OPENVINO=1 GGML_OPENVINO_DEVICE=GPU GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - ubuntu-cpu-cmake-riscv64-native: runs-on: RISCV64 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1620d9a1b..f32963007 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -238,7 +238,7 @@ jobs: openvino_version: ${{ steps.openvino_version.outputs.value }} env: - # Sync versions in build.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile + # Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile OPENVINO_VERSION_MAJOR: "2026.0" OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886" diff --git a/.github/workflows/server-metal.yml b/.github/workflows/server-self-hosted.yml similarity index 56% rename from .github/workflows/server-metal.yml rename to .github/workflows/server-self-hosted.yml index 1d707bef4..a11aea7e8 100644 --- a/.github/workflows/server-metal.yml +++ b/.github/workflows/server-self-hosted.yml @@ -1,4 +1,4 @@ -name: Server-Metal +name: Server (self-hosted) on: workflow_dispatch: # allows manual triggering @@ -14,7 +14,7 @@ on: push: branches: - master - paths: ['.github/workflows/server-metal.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*'] + paths: ['.github/workflows/server-self-hosted.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*'] env: LLAMA_LOG_COLORS: 1 @@ -28,7 +28,7 @@ concurrency: jobs: server-metal: - runs-on: [self-hosted, macOS, ARM64] + runs-on: [self-hosted, llama-server, macOS, ARM64] name: server-metal (${{ matrix.wf_name }}) strategy: @@ -71,3 +71,42 @@ jobs: pip install -r requirements.txt export ${{ matrix.extra_args }} pytest -v -x -m "not slow" + + server-cuda: + runs-on: [self-hosted, llama-server, Linux, NVIDIA] + + name: server-cuda (${{ matrix.wf_name }}) + strategy: + matrix: + build_type: [Release] + 
wf_name: ["GPUx1"] + include: + - build_type: Release + extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1" + wf_name: "GPUx1, backend-sampling" + fail-fast: false + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: Build + id: cmake_build + run: | + cmake -B build -DGGML_SCHED_NO_REALLOC=ON + cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server + + - name: Tests + id: server_integration_tests + if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }} + run: | + cd tools/server/tests + python3 -m venv venv + source venv/bin/activate + pip install -r requirements.txt + export ${{ matrix.extra_args }} + pytest -v -x -m "not slow" From b30a5fdf370aa4cec02f6bbcc899fbbf3e77e763 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Mar 2026 23:15:47 +0200 Subject: [PATCH 5/7] metal : add FA specialization for HSK = 320, HSV = 256 (#20549) --- ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal.metal | 19 +++++++++++++++++++ tests/test-backend-ops.cpp | 10 ++++++---- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index b7d587f3b..82101f471 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1142,6 +1142,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 128 && op->src[0]->ne[0] != 192 && op->src[0]->ne[0] != 256 && + op->src[0]->ne[0] != 320 && op->src[0]->ne[0] != 576) { return false; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index d4b129ed7..b2328605d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -6176,6 +6176,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6190,6 +6191,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_HAS_BF16) @@ -6205,6 +6207,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at template 
[[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif @@ -6220,6 +6223,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6234,6 +6238,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6248,6 +6253,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6262,6 +6268,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; 
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6276,6 +6283,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES @@ -6846,6 +6854,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index abf914faa..c9896cc11 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8576,11 +8576,12 @@ static std::vector> make_test_cases_eval() { } } - for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) { + for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 320, 576 }) { for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) { - if (hsk != 192 && hsk != 576 && hsk != hsv) continue; + if (hsk != 192 && hsk != 320 && hsk != 576 && hsk != hsv) continue; if (hsk == 192 && (hsv != 128 && hsv != 192)) continue; if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA + if (hsk == 320 && hsv != 256) continue; // MLA for (bool mask : { true, false } ) { for (bool sinks : { true, false } ) { @@ -8589,12 +8590,13 @@ static std::vector> make_test_cases_eval() { for (float logit_softcap : {0.0f, 10.0f}) { if (hsk != 128 && logit_softcap != 0.0f) continue; for (int nh : { 1, 4 }) { - if (nh == 1 && hsk != 576) continue; // GLM 4.7 Flash + if (nh == 1 && hsk != 320 && hsk != 576) 
continue; // GLM 4.7 Flash for (int nr3 : { 1, 3, }) { if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes - for (int nr2 : { 1, 4, 12, 20 }) { + for (int nr2 : { 1, 4, 12, 20, 32 }) { if (nr2 == 12 && hsk != 128) continue; if (nr2 == 20 && (nh != 1 || hsk != 576)) continue; + if (nr2 == 32 && (nh != 1 || hsk != 320)) continue; //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) { for (int kv : { 113, 512, 1024, }) { if (nr2 != 1 && kv != 512) continue; From d23355afc319f598d0e588a2d16a4da82e14ff41 Mon Sep 17 00:00:00 2001 From: Michael Wand Date: Sat, 14 Mar 2026 14:44:42 -0700 Subject: [PATCH 6/7] model : wire up Qwen3.5/Qwen3.5MoE tensors for NVFP4 support (#20506) --- src/llama-model.cpp | 26 ++++++++++++++++++++++++++ src/llama-model.h | 8 ++++++++ src/models/qwen35.cpp | 24 ++++++++++++------------ src/models/qwen35moe.cpp | 29 ++++++++++++++++------------- 4 files changed, 62 insertions(+), 25 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6cc28eff2..e8e1bbf1c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7462,6 +7462,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!layer.wo_s && layer.wo) { layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.wqkv_s && layer.wqkv) { + layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_s && layer.wqkv_gate) { + layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } // dense FFN weight scales (per-tensor, shape {1}) if (!layer.ffn_gate_s && layer.ffn_gate) { @@ -7473,6 +7479,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!layer.ffn_up_s && layer.ffn_up) { layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } // MoE expert weight scales (per-expert, shape {n_expert}) if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) { @@ -7484,6 +7499,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!layer.ffn_up_exps_s && layer.ffn_up_exps) { layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); } + + // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_out_s && layer.ssm_out) { + layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_s && layer.ssm_alpha) { + layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_s && layer.ssm_beta) { + layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } } } diff --git a/src/llama-model.h b/src/llama-model.h index 9a2dacecc..25bf892e7 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -401,9 +401,17 @@ struct llama_layer { struct ggml_tensor * wk_s = nullptr; struct ggml_tensor 
* wv_s = nullptr; struct ggml_tensor * wo_s = nullptr; + struct ggml_tensor * wqkv_s = nullptr; + struct ggml_tensor * wqkv_gate_s = nullptr; struct ggml_tensor * ffn_gate_s = nullptr; struct ggml_tensor * ffn_up_s = nullptr; struct ggml_tensor * ffn_down_s = nullptr; + struct ggml_tensor * ffn_gate_shexp_s = nullptr; + struct ggml_tensor * ffn_up_shexp_s = nullptr; + struct ggml_tensor * ffn_down_shexp_s = nullptr; + struct ggml_tensor * ssm_out_s = nullptr; + struct ggml_tensor * ssm_alpha_s = nullptr; + struct ggml_tensor * ssm_beta_s = nullptr; // altup & laurel struct ggml_tensor * per_layer_inp_gate = nullptr; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index e12dad700..3108bf331 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -90,11 +90,11 @@ std::pair llm_build_qwen35::build_qkvz( const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s); qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); cb(qkv_mixed, "linear_attn_qkv_mixed", il); - ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s); cb(z, "z", il); return { qkv_mixed, z }; @@ -123,7 +123,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] cb(Qcur_full, "Qcur_full", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, @@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); // Apply K normalization @@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( cur = ggml_mul(ctx0, cur, gate_sigmoid); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; @@ -217,13 +217,13 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * qkv_mixed = qkvz.first; ggml_tensor * z = qkvz.second; - ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s); beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, 
n_seqs); cb(alpha, "alpha", il); @@ -356,7 +356,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(final_output, "final_output", il); // Output projection - cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s); cb(cur, "linear_attn_out", il); // Reshape back to original dimensions @@ -370,9 +370,9 @@ ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr); cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 8d07c7ed2..165e2412e 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -90,11 +90,11 @@ std::pair llm_build_qwen35moe::build_qkvz( const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s); qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); cb(qkv_mixed, "linear_attn_qkv_mixed", il); - ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s); cb(z, "z", il); return { qkv_mixed, z }; @@ -123,7 +123,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] cb(Qcur_full, "Qcur_full", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, @@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); // Apply K normalization @@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( cur = ggml_mul(ctx0, cur, gate_sigmoid); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; @@ -217,13 +217,13 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( ggml_tensor * qkv_mixed = qkvz.first; ggml_tensor * z = qkvz.second; - ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s); beta = 
ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); @@ -356,7 +356,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(final_output, "final_output", il); // Output projection - cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s); cb(cur, "linear_attn_out", il); // Reshape back to original dimensions @@ -380,16 +380,19 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int LLM_FFN_SILU, true, hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, - nullptr, model.layers[il].ffn_gate_up_exps); + nullptr, model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(moe_out, "ffn_moe_out", il); // Add shared experts if present - following Qwen3Next reference implementation if (model.layers[il].ffn_up_shexp != nullptr) { ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il); From 6b10a82c00064d4ead889b09d7fae9eff6927d57 Mon Sep 17 00:00:00 2001 From: sprayandwipe Date: Sun, 15 Mar 2026 07:11:19 +0000 Subject: [PATCH 7/7] kv-cache : fix reading llama_kv_cell_ext during state read (#20273) Co-authored-by: sid --- src/llama-kv-cache.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 82fe58fac..01166fac9 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1953,6 +1953,12 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 cells.pos_set(i, pos); + if (hparams.n_pos_per_embd() > 1) { + llama_kv_cell_ext ext; + io.read_to(&ext, sizeof(ext)); + cells.ext_set(i, ext); + } + for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; io.read_to(&seq_id, sizeof(seq_id));
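
Two notes on the patches above. First, on the PATCH 5/7 test coverage: the loop in tests/test-backend-ops.cpp only admits a few asymmetric (HSK, HSV) head-size pairs on top of the square cases, and this patch extends that set with 320/256. The sketch below restates those pairing rules as a standalone predicate; the function name and the driver in main() are illustrative, not llama.cpp API.

```cpp
// Standalone sketch of the (HSK, HSV) combinations the updated test loop
// admits: equal head sizes, plus the asymmetric pairs used by MLA-style
// attention, now including 320/256.
#include <cstdio>

static bool fa_head_sizes_supported(int hsk, int hsv) {
    if (hsk == 192) return hsv == 128 || hsv == 192;
    if (hsk == 320) return hsv == 256;  // new pair added by this patch
    if (hsk == 576) return hsv == 512;  // DeepSeek MLA
    return hsk == hsv;                  // all other head sizes must match
}

int main() {
    const int hsk[] = { 40, 64, 72, 80, 96, 128, 192, 256, 320, 576 };
    const int hsv[] = { 40, 64, 72, 80, 96, 128, 192, 256, 512 };
    for (int k : hsk)
        for (int v : hsv)
            if (fa_head_sizes_supported(k, v))
                std::printf("dk%-3d dv%d\n", k, v);
}
```

The Metal template instantiations added in PATCH 5/7 mirror exactly this set: one dk320_dv256 kernel per quantization type, for both the block and vec flash-attention paths.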
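
Second, on why the PATCH 7/7 fix matters: the state writer already emits one llama_kv_cell_ext blob per cell whenever hparams.n_pos_per_embd() > 1 (multi-position models such as those using M-RoPE), so a reader that skips the blob misinterprets every byte that follows — the next read_to() lands on ext data instead of the fields after it. Below is a minimal sketch of that read/write symmetry under assumed, illustrative types and cell layout; cell_ext, io_buf, and the field order here are stand-ins, not the actual llama.cpp definitions.

```cpp
// Sketch: the reader must consume the optional ext blob under the *same*
// condition the writer emitted it, or the stream desynchronizes.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

struct cell_ext { int32_t pos[4]; };  // stand-in for llama_kv_cell_ext

struct io_buf {                       // stand-in for llama_io_write_i / llama_io_read_i
    std::vector<uint8_t> data;
    size_t off = 0;
    void write(const void * p, size_t n) {
        const uint8_t * b = static_cast<const uint8_t *>(p);
        data.insert(data.end(), b, b + n);
    }
    void read_to(void * p, size_t n) {
        assert(off + n <= data.size());
        std::memcpy(p, data.data() + off, n);
        off += n;
    }
};

// writer side (assumed layout): pos, optional ext, n_seq_id, seq_ids
static void write_cell(io_buf & io, int32_t pos, const cell_ext & ext, bool has_ext) {
    io.write(&pos, sizeof(pos));
    if (has_ext) {
        io.write(&ext, sizeof(ext)); // emitted only for multi-position models
    }
    uint32_t n_seq_id = 1;
    int32_t  seq_id   = 0;
    io.write(&n_seq_id, sizeof(n_seq_id));
    io.write(&seq_id, sizeof(seq_id));
}

// reader side: without the ext read, n_seq_id below would be parsed
// out of ext bytes and everything after it would be garbage
static void read_cell(io_buf & io, bool has_ext) {
    int32_t pos;
    io.read_to(&pos, sizeof(pos));
    if (has_ext) {
        cell_ext ext;
        io.read_to(&ext, sizeof(ext));
    }
    uint32_t n_seq_id;
    io.read_to(&n_seq_id, sizeof(n_seq_id));
    for (uint32_t j = 0; j < n_seq_id; ++j) {
        int32_t seq_id;
        io.read_to(&seq_id, sizeof(seq_id));
    }
}

int main() {
    io_buf io;
    write_cell(io, /*pos=*/42, cell_ext{{42, 7, 3, 0}}, /*has_ext=*/true);
    read_cell(io, /*has_ext=*/true); // passing false here reproduces the pre-fix desync
}
```

The essential point is that both sides key off an identical predicate; in the patch that predicate is hparams.n_pos_per_embd() > 1.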