diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
index 2666a78a9..9e8c9966e 100644
--- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
+++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
@@ -742,17 +742,45 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
 // activations : fp32 -> fp16
 static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) {
-    for (int r = 0; r < n_rows; r += 2) {
+    const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS);
+    const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS;
+
+    int r = 0;
+
+    #pragma unroll(2)
+    for (r = 0; r < n_rows_tiled; r += 2) {
         int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
         int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx
 
-        const bool next_row_valid = (r + 1) < n_rows;
-
         const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride);
         const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride);
 
         for (int c = 0; c < k_block; c += 32) {
             HVX_Vector v0 = *pv_in0++;
-            HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero();
+            HVX_Vector v1 = *pv_in1++;
+
+            HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);
+
+            // compute output position
+            int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index
+            int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0;
+
+            HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS);
+            tile[r1 / 2] = v_out;
+        }
+    }
+
+    for (; r < n_rows_padded; r += 2) {
+        int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
+        int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx
+
+        const bool row0_valid = r < n_rows;
+        const bool row1_valid = (r + 1) < n_rows;
+
+        const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL;
+        const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL;
+        for (int c = 0; c < k_block; c += 32) {
+            HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero();
+            HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero();
 
             HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);
 
@@ -889,7 +917,9 @@ static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct h
     // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
     const size_t m_block_cost = (size_t) n * 3;
     const size_t n_block_cost = (size_t) m * 2;
-    if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE,
+    if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn,
+                           hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
+                           m_block_cost, n_block_cost, &M_BLOCK_SIZE,
                            &N_BLOCK_SIZE, &vtcm_used) != 0) {
         FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
         return -1;
@@ -1084,7 +1114,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
 
     if (m >= 128) {
         size_t mc = 0, nc = 0, used = 0;
-        if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n,
+        if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn,
+                               hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
                                /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2,
                                &mc, &nc, &used) == 0 &&
             hmx_ceil_div((size_t) n, nc) >= 2) {
@@ -1096,7 +1127,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
     }
 
     if (!use_pipeline) {
-        if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n,
+        if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn,
+                               hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
                                /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2,
                                &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
             FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
@@ -1432,7 +1464,8 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
     if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
                            /*per_n=*/3 * vec_dot_size,
                            /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
-                           /*per_mn=*/sizeof(__fp16), params->m, params->n,
+                           /*per_mn=*/sizeof(__fp16),
+                           hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n,
                            /*m_block_cost=*/(size_t) params->n, /*n_block_cost=*/(size_t) params->m,
                            &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
         FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
@@ -1612,7 +1645,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
                            /*per_n=*/3 * vec_dot_size, // W + S0 + S1
                            /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
                            /*per_mn=*/sizeof(__fp16), // O
-                           m, n,
+                           hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
                            /*m_block_cost=*/(size_t) n, /*n_block_cost=*/(size_t) m,
                            &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
         FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c
index a0c265132..2461ae617 100644
--- a/ggml/src/ggml-hexagon/htp/matmul-ops.c
+++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c
@@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) {
         return op_matmul_hvx(octx);
     }
 
-    // M alignment: when M > 32 but not 32-aligned, we split into
-    // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
-    // When M <= 32 and not 32-aligned, fall back entirely to HVX.
+    // M alignment: use HMX whenever M >= 32; the last partial tile (m_total % 32 rows)
+    // is handled by HMX itself. When M < 32, fall back to HVX.
     const int m_total = (int) src1->ne[1];
-    const int m_tail = m_total % 32;
-    const int m_hmx = m_total - m_tail;
+    const int m_hmx = m_total & ~31; // 0 when M < 32
 
     if (m_hmx == 0) {
         return op_matmul_hvx(octx);
@@ -3009,7 +3007,6 @@ int op_matmul(struct htp_ops_context * octx) {
     int k = (int) src0->ne[0]; // inner dimension
     int n = (int) src0->ne[1]; // weight columns
 
-    // --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
     int ret = -1;
 
     // Row strides in elements. For compact tensors these equal k; for
@@ -3027,7 +3024,7 @@ int op_matmul(struct htp_ops_context * octx) {
         .dst = (float *) dst->data,
         .activation = (float *) src1->data,
         .permuted_weight = (const __fp16 *) src0->data,
-        .m = m_hmx,
+        .m = m_total,
         .k = k,
         .n = n,
         .act_stride = act_stride,
@@ -3048,12 +3045,12 @@ int op_matmul(struct htp_ops_context * octx) {
         } else {
             ret = hmx_mat_mul_permuted_w16a32(octx->ctx, (float*) dst->data, (float*) src1->data,
                                               (const __fp16 *) src0->data,
-                                              m_hmx, k, n, act_stride, wgt_stride);
+                                              m_total, k, n, act_stride, wgt_stride);
         }
     } else {
         ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, (float*) dst->data, (float*) src1->data,
                                                (const uint8_t *) src0->data,
-                                               m_hmx, k, n, (int) src0->type);
+                                               m_total, k, n, (int) src0->type);
     }
 
     if (ret != 0) {
@@ -3061,27 +3058,6 @@ int op_matmul(struct htp_ops_context * octx) {
         return op_matmul(octx);
     }
 
-    // --- Phase 2: HVX on the remaining m_tail rows ---
-    if (m_tail > 0) {
-        // copy of src1 and dst
-        struct htp_tensor src1_tail = *src1;
-        struct htp_tensor dst_tail = *dst;
-
-        src1_tail.ne[1] = m_tail; // only tail rows
-        dst_tail.ne[1] = m_tail; // only tail rows
-
-        // Offset activation and dst pointers past the HMX-processed rows.
-        // Use nb[1] (row stride in bytes) to compute the byte offset.
-        src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
-        dst_tail.data += (uint32_t) m_hmx * dst->nb[1];
-
-        octx->src[1] = &src1_tail;
-        octx->dst = &dst_tail;
-
-        FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
-        return op_matmul_hvx(octx);
-    }
-
     return 0;
 #endif // HTP_HAS_HMX
 }
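
Note (not part of the patch): the padded packing done by transfer_activation_chunk_fp32_to_fp16 above can be summarized by the plain-C reference sketch below. Everything in it is illustrative only: the tile sizes are assumed stand-ins for HMX_FP16_TILE_N_ROWS / HMX_FP16_TILE_N_COLS / HMX_FP16_TILE_N_ELMS, the helper name is hypothetical, and the intra-tile element order is simplified to row-major, whereas the HVX code interleaves row pairs via hvx_vec_f32_to_f16_shuff. The hex_align_up(m, HMX_FP16_TILE_N_ROWS) arguments added to the hmx_compute_chunks calls budget VTCM for this padded row count.

/* Hypothetical scalar reference for the padded fp32 -> fp16 tile packing. */
#include <stddef.h>

#define TILE_ROWS 32                       /* assumed stand-in for HMX_FP16_TILE_N_ROWS */
#define TILE_COLS 32                       /* assumed stand-in for HMX_FP16_TILE_N_COLS */
#define TILE_ELMS (TILE_ROWS * TILE_COLS)  /* assumed stand-in for HMX_FP16_TILE_N_ELMS */

static void pack_fp32_to_fp16_tiles_ref(__fp16 *dst, const float *src,
                                        int n_rows, int k_block, int k_stride) {
    /* Pad the row count up to a whole number of tiles (TILE_ROWS is a power of two). */
    const int n_rows_padded = (n_rows + TILE_ROWS - 1) & ~(TILE_ROWS - 1);
    const int tiles_per_row = k_block / TILE_COLS; /* k_block assumed a multiple of TILE_COLS */

    for (int r = 0; r < n_rows_padded; r++) {
        for (int c = 0; c < k_block; c++) {
            int tile_idx = (r / TILE_ROWS) * tiles_per_row + (c / TILE_COLS);
            int in_tile  = (r % TILE_ROWS) * TILE_COLS + (c % TILE_COLS);
            /* Rows past n_rows are zero-filled so every tile handed to HMX is initialized. */
            float v = (r < n_rows) ? src[(size_t) r * k_stride + c] : 0.0f;
            dst[(size_t) tile_idx * TILE_ELMS + in_tile] = (__fp16) v;
        }
    }
}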