diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index e05ccfd5f..3ef0bcdb2 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -201,11 +201,10 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32 // Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using // full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. -// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes. -static inline void dequantize_x4x2_q4_0_x4groups_hvx( +// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. +static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt, - HVX_Vector out[4]) { + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { // Load all 128 packed bytes (4 contiguous 32-byte groups) HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); @@ -221,8 +220,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b - volatile HVX_Vector vscale = hvx_vmemu(scales_4); - + HVX_Vector vscale = hvx_vmemu(scales_4); HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); @@ -230,8 +228,9 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - out[0] = v_lo; // group0 already in [0:63] - out[1] = v_hi; // group2 already in [0:63] + HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */ + v_hi /* group2 already in [0:63] */ }; + return r; } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. @@ -292,12 +291,11 @@ static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed } // Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). -static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, +static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, bool upper_nibbles, int sub_blk_base, const HVX_Vector vlut_cvt, - mxfp4_scales_t scales, - HVX_Vector out[4]) { + mxfp4_scales_t scales) { HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; @@ -318,10 +316,8 @@ static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_12 v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - out[0] = v_lo; - out[1] = Q6_V_vror_VR(v_lo, 64); - out[2] = v_hi; - out[3] = Q6_V_vror_VR(v_hi, 64); + HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) }; + return r; } // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. @@ -372,18 +368,18 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - HVX_Vector v0[2]; const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - - r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -415,21 +411,21 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - HVX_Vector v0[4], v1[4]; - dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0); + HVX_Vector_x4 dv0, dv1; + dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); if (row1 < n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1); + dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); } else { - v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero(); } for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -612,11 +608,13 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; @@ -832,10 +830,6 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); } -// - -#define FALLBACK_TO_STANDARD 1 - // C += AB static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, @@ -861,314 +855,80 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); } - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } + Q6_mxmem_AR_after_hf(accum_tile, 0); } } } -static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, - float *restrict out, const float *restrict x, const uint8_t *restrict w, - int m, int k, int n, int weight_type) { - // assume k % 32 == 0 && n % 32 == 0 - const size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - const size_t vtcm_budget = ctx->vtcm_size; - - const size_t K_BLOCK_SIZE = 1024; - - // Fallback: if k doesn't need K-blocking, out-stationary has no advantage - const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; - if (k_iters_check <= 1) { - FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); - return FALLBACK_TO_STANDARD; - } - - // Dynamic M,N search via hmx_compute_chunks - const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); - const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) - + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) - const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) - + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) - const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - - // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; - // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin - const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; - const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment - - size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; - // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. - // From profiling: wt_dequant per element ≈ 1.5× activation load per element. - // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). - // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). - const size_t m_block_cost = (size_t) n * 3; - const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - m_block_cost, n_block_cost, &M_BLOCK_SIZE, - &N_BLOCK_SIZE, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } - - // Compute precise buffer sizes from searched M,N and fixed K - const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); - const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); - - const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; - if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, - vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); - return -1; - } - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); - uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); - uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); - __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, - M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - // initialize eye tile (32x32 identity matrix) - { - HVX_Vector v; - v = Q6_V_vzero(); - v = Q6_Vw_vinsert_VwR(v, 0x3c000000); - v = Q6_V_vror_VR(v, VLEN - 4); - v = Q6_Vw_vinsert_VwR(v, 0x00003c00); - for (int i = 0; i < 16; ++i) { - ((HVX_Vector *) vtcm_eye_tile)[i] = v; - v = Q6_V_vror_VR(v, VLEN - 8); - } - } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - TIMER_DEFINE(fetch); - TIMER_DEFINE(act_load); - TIMER_DEFINE(wt_dequant); - TIMER_DEFINE(core); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { - size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); - for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { - size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); - - const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); - - for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); - - TIMER_START(fetch); - // fetch activation block into VTCM - { - const float *activation_block = x + mr * k + kk; - - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_scratch1, activation_block), - k_blk_sz * sizeof(float), - k * sizeof(float), - k_blk_sz * sizeof(float), - m_blk_sz); - } - - // fetch weight block into VTCM (x4x2 sub-block: quants + scales) - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); - { - const int blk_start = kk / QK_Q4_0x4x2; - const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; - uint8_t *dst = vtcm_scratch0; - const uint8_t *src = w + nc * row_stride; - const size_t n_rows = n_blk_sz; - const size_t src_stride = row_stride; - const size_t dst_stride = sub_row_stride; - const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); - const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); - const size_t scale_off = full_qrow + blk_start * scale_blk_size; - const size_t scale_width = nb_sub * scale_blk_size; - - // 2D DMA: quants sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); - // 2D DMA: scales sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); - } - TIMER_STOP(fetch); - - TIMER_START(act_load); - // load activation block - { - dma_queue_pop(ctx->dma[0]); // wait for act DNA - transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); - } - TIMER_STOP(act_load); - - TIMER_START(wt_dequant); - // dequantize weight block - { - dma_queue_pop(ctx->dma[0]); - dma_queue_pop(ctx->dma[0]); - // vtcm_scratch0 is used to store the qweight chunk - // worker_pool_run_func already returned, so fetch is done - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, - n_blk_sz, k_blk_sz, sub_row_stride, weight_type); - } - TIMER_STOP(wt_dequant); - - // core mma - TIMER_START(core); - { - core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, - n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); - } - TIMER_STOP(core); - } - - // store output block - { - float *output_block = out + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); - } - } - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", - TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); -#endif - return 0; -} - -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, int weight_type) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { return -1; } - // for large m, k (e.g. prefill FFN Down), use out-stationary version - if (m >= 128 && k > n && n > 1024) { - int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); - if (rc != FALLBACK_TO_STANDARD) { - return rc; // 0 success, -1 error - } - FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); - // fall through to standard path - } - size_t row_stride = get_x4x2_row_stride(weight_type, k); if (row_stride == 0) { return -1; } - FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - // Only pays off when the chunker yields >=2 n-chunks, so the main loop can - // overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for - // double-buffered output and the worker-dispatch overhead are pure loss. - // Try pipeline costs first; fall back to sequential if the layout collapses - // to one n-chunk. m >= 128 floor keeps HMX utilization reasonable. - const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) - const size_t seq_per_mn = sizeof(__fp16); // O x 1 + const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - bool use_pipeline = false; - - if (m >= 128) { - size_t mc = 0, nc = 0, used = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && - hmx_ceil_div((size_t) n, nc) >= 2) { - m_chunk_n_rows = mc; - n_chunk_n_cols = nc; - vtcm_used = used; - use_pipeline = true; - } + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + return -1; } - if (!use_pipeline) { - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } - } - - // Compute precise buffer sizes per execution path - const size_t weight_area_size = hex_align_up( - n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up( - m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); size_t scratch0_size, scratch1_size, scratch2_size; - if (use_pipeline) { - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 - } else { - scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 - scratch1_size = scratch0_size; // x4x2 DMA buf 1 - scratch2_size = 0; // unused - } + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, weight_type, use_pipeline, - m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); TIMER_DEFINE(weight_load); @@ -1178,184 +938,115 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_DEFINE(total); TIMER_START(total); - FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", - use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) + // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - if (!use_pipeline) { - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + // A --> B: vtcm_qweight, 1 buffer + // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers + // C --> D: vtcm_output0/vtcm_output1, 2 buffers - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } - TIMER_STOP(activation_load); + // Async timeline (C overlaps B+D): + // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] + // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); - } + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready - - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - - const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); - } - - // Dequant + vscatter writes directly to [K, N] transposed tiles. - // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); - - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); - - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - } else { - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); } - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); - } - - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); - } - } - - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } - - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } - - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); - - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); - } + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); } } - hmx_queue_suspend(ctx->hmx_queue); + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + } + + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); + + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's + // counterpart — and (i+1)%2 was last used by C_{i-1} which completed + // before C_i was submitted. + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); + } + } } + hmx_queue_suspend(ctx->hmx_queue); + TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline); + FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1370,15 +1061,15 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // -static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; } -static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; } -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { const int r2 = hmx_matmul_batch_r2(params); const int r3 = hmx_matmul_batch_r3(params); @@ -1387,37 +1078,36 @@ static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_ (size_t) (dst_b3 / r3) * params->src0_nb3); } -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (const float *) ((const uint8_t *) params->activation + (size_t) dst_b2 * params->src1_nb2 + (size_t) dst_b3 * params->src1_nb3); } -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (float *) ((uint8_t *) params->dst + (size_t) dst_b2 * params->dst_nb2 + (size_t) dst_b3 * params->dst_nb3); } -static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params) { +static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_f16_f32_batched_params_t *params) { int ret = 0; for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_mat_mul_permuted_w16a32(ctx, - hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); + ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); } } return ret; } -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) { if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } if (!params->m || !params->k || !params->n) { return -1; } if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } @@ -1435,7 +1125,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if (group_size <= 1) { FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } // Grouped path: reuse interleaved weight across all q_heads sharing a @@ -1464,7 +1154,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu /*m_block_cost=*/(size_t) params->n, /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads @@ -1486,7 +1176,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 @@ -1614,7 +1304,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu // -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index 1c78ffadd..f114edb82 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -33,14 +33,14 @@ typedef struct { size_t src1_nb3; size_t dst_nb2; size_t dst_nb3; -} hmx_matmul_w16a32_batched_params_t; +} hmx_matmul_f16_f32_batched_params_t; // HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output // act_stride: activation row stride in elements (= k for contiguous, or // nb[1]/sizeof(float) for permuted tensors like attention Q). // weight_stride: weight row stride in elements (= k for compact weights, or // nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const __fp16 *permuted_weight, @@ -48,13 +48,12 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, int act_stride, int weight_stride); -// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32. +// Batched F16 wrapper over hmx_mat_mul_f16_f32. // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params); +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL) -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, +// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 8e54536f6..e86193884 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -87,6 +87,27 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } +#if __HVX_ARCH__ >= 75 + { + // Power on HMX and set HMX clock + HAP_power_request_t request; + memset(&request, 0, sizeof(HAP_power_request_t)); + request.type = HAP_power_set_HMX_v2; + request.hmx_v2.set_power = TRUE; + request.hmx_v2.power_up = TRUE; + request.hmx_v2.set_clock = TRUE; + request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; + FARF(ALWAYS, "Setting HMX clock\n"); + err = HAP_power_set((void *) ctx, &request); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Error setting HMX clock."); + return err; + } + } +#else { // Power on HMX HAP_power_request_t request; @@ -94,31 +115,12 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.type = HAP_power_set_HMX; request.hmx.power_up = TRUE; FARF(ALWAYS, "Powering HMX on\n"); - err = HAP_power_set((void *) &ctx, &request); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { FARF(ERROR, "Error powering on HMX."); return err; } } - -#if __HVX_ARCH__ >= 75 - { - // Set HMX clock - HAP_power_request_t request; - memset(&request, 0, sizeof(HAP_power_request_t)); - request.type = HAP_power_set_HMX_v2; - request.hmx_v2.set_clock = TRUE; - request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; - FARF(ALWAYS, "Setting HMX clock\n"); - err = HAP_power_set((void *) &ctx, &request); - if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); - return err; - } - } #endif return AEE_SUCCESS; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 2461ae617..46fc5602d 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -2995,7 +2995,6 @@ int op_matmul(struct htp_ops_context * octx) { // is handled by HMX itself; when M < 32 fall back to HVX. const int m_total = (int) src1->ne[1]; const int m_hmx = m_total & ~31; // 0 when M < 32 - if (m_hmx == 0) { return op_matmul_hvx(octx); } @@ -3020,7 +3019,7 @@ int op_matmul(struct htp_ops_context * octx) { if (src0->type == HTP_TYPE_F16) { if (is_batched) { - hmx_matmul_w16a32_batched_params_t batch_params = { + hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, .permuted_weight = (const __fp16 *) src0->data, @@ -3041,15 +3040,14 @@ int op_matmul(struct htp_ops_context * octx) { .dst_nb2 = dst->nb[2], .dst_nb3 = dst->nb[3], }; - ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params); + ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_mat_mul_permuted_w16a32(octx->ctx, + ret = hmx_matmul_f16_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, m_total, k, n, act_stride, wgt_stride); } } else { - ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, m_total, k, n, (int) src0->type); } diff --git a/scripts/snapdragon/adb/run-bench.sh b/scripts/snapdragon/adb/run-bench.sh index 27459df24..bbe7146b4 100755 --- a/scripts/snapdragon/adb/run-bench.sh +++ b/scripts/snapdragon/adb/run-bench.sh @@ -45,5 +45,5 @@ adb $adbserial $adbhost shell " \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ $ndev $nhvx $opmask $verbose $profile $hb ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --ubatch-size 256 -fa 1 -ngl 99 $cli_opts $@ \ + --ubatch-size 1024 -fa 1 -ngl 99 $cli_opts $@ \ " diff --git a/scripts/snapdragon/adb/run-cli.sh b/scripts/snapdragon/adb/run-cli.sh index 3f89a7777..48127dfa2 100755 --- a/scripts/snapdragon/adb/run-cli.sh +++ b/scripts/snapdragon/adb/run-cli.sh @@ -73,6 +73,6 @@ adb $adbserial $adbhost shell " \ $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $opflt $vmem $mbuf \ ./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --ctx-size 8192 --ubatch-size 256 -fa on \ + --ctx-size 8192 --ubatch-size 1024 -fa on \ -ngl 99 --device $device $cli_opts $@ \ " diff --git a/scripts/snapdragon/adb/run-completion.sh b/scripts/snapdragon/adb/run-completion.sh index 1b904c77b..0d9b10271 100755 --- a/scripts/snapdragon/adb/run-completion.sh +++ b/scripts/snapdragon/adb/run-completion.sh @@ -69,6 +69,6 @@ adb $adbserial $adbhost shell " \ $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $opflt $vmem $mbuf \ ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --ctx-size 8192 --ubatch-size 256 -fa on \ - -ngl 99 --device $device $cli_opts $@ \ + --ctx-size 8192 --ubatch-size 1024 -fa on \ + -ngl 99 --device $device $cli_opts $@ \ " diff --git a/scripts/snapdragon/adb/run-mtmd.sh b/scripts/snapdragon/adb/run-mtmd.sh index 38467beba..992045cb9 100755 --- a/scripts/snapdragon/adb/run-mtmd.sh +++ b/scripts/snapdragon/adb/run-mtmd.sh @@ -66,6 +66,6 @@ adb $adbserial $adbhost shell " \ --mmproj $basedir/../gguf/$mmproj \ --image $basedir/../gguf/$image \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --ctx-size 8192 --ubatch-size 256 -fa on \ + --ctx-size 8192 --ubatch-size 1024 -fa on \ -ngl 99 --device $device -v $cli_opts $@ \ " diff --git a/scripts/snapdragon/windows/run-bench.ps1 b/scripts/snapdragon/windows/run-bench.ps1 index 8bf6939d2..5ee81df68 100644 --- a/scripts/snapdragon/windows/run-bench.ps1 +++ b/scripts/snapdragon/windows/run-bench.ps1 @@ -45,4 +45,4 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib" & "$basedir\bin\llama-bench.exe" ` --mmap 0 -m $basedir\..\..\gguf\$model ` --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 ` - --batch-size 128 -ngl 99 --device $device $cli_opts + --ubatch-size 1024 -ngl 99 --device $device $cli_opts diff --git a/scripts/snapdragon/windows/run-cli.ps1 b/scripts/snapdragon/windows/run-cli.ps1 index 104452f9b..b51149bec 100644 --- a/scripts/snapdragon/windows/run-cli.ps1 +++ b/scripts/snapdragon/windows/run-cli.ps1 @@ -49,5 +49,5 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib" & "$basedir\bin\llama-cli.exe" ` --no-mmap -m $basedir\..\..\gguf\$model ` --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 ` - --ctx-size 8192 --ubatch-size 256 -fa on ` + --ctx-size 8192 --ubatch-size 1024 -fa on ` -ngl 99 --device $device $cli_opts diff --git a/scripts/snapdragon/windows/run-completion.ps1 b/scripts/snapdragon/windows/run-completion.ps1 index 5841a82fa..ffce8184d 100644 --- a/scripts/snapdragon/windows/run-completion.ps1 +++ b/scripts/snapdragon/windows/run-completion.ps1 @@ -49,5 +49,5 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib" & "$basedir\bin\llama-completion.exe" ` --no-mmap -m $basedir\..\..\gguf\$model ` --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 ` - --ctx-size 8192 --batch-size 256 -fa on ` + --ctx-size 8192 --ubatch-size 1024 -fa on ` -ngl 99 -no-cnv --device $device $cli_opts diff --git a/scripts/snapdragon/windows/run-mtmd.ps1 b/scripts/snapdragon/windows/run-mtmd.ps1 index be8178751..b38fae35f 100644 --- a/scripts/snapdragon/windows/run-mtmd.ps1 +++ b/scripts/snapdragon/windows/run-mtmd.ps1 @@ -64,5 +64,5 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib" --mmproj $basedir\..\..\gguf\$mmproj ` --image $basedir\..\..\gguf\$image ` --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 ` - --ctx-size 8192 --ubatch-size 256 -fa on ` + --ctx-size 8192 --ubatch-size 1024 -fa on ` -ngl 99 --device $device -v $cli_opts