Hexagon: Process M-tail rows on HMX instead of HVX (#22724)

* hex-mm: process m-tail rows on HMX instead of HVX

* hmx-mm: unroll and optimize padded activation loop

---------

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
Trivikram Reddy 2026-05-05 11:43:03 -05:00 committed by GitHub
parent ff806a110d
commit bbeb89d76c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 48 additions and 39 deletions

View file

@ -742,17 +742,45 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
// activations : fp32 -> fp16
static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) {
for (int r = 0; r < n_rows; r += 2) {
const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS);
const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS;
int r = 0;
#pragma unroll(2)
for (r = 0; r < n_rows_tiled; r += 2) {
int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx
const bool next_row_valid = (r + 1) < n_rows;
const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride);
const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride);
for (int c = 0; c < k_block; c += 32) {
HVX_Vector v0 = *pv_in0++;
HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero();
HVX_Vector v1 = *pv_in1++;
HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);
// compute output position
int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index
int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0;
HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS);
tile[r1 / 2] = v_out;
}
}
for (; r < n_rows_padded; r += 2) {
int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx
const bool row0_valid = r < n_rows;
const bool row1_valid = (r + 1) < n_rows;
const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL;
const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL;
for (int c = 0; c < k_block; c += 32) {
HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero();
HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero();
HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);
@ -889,7 +917,9 @@ static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct h
// n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
const size_t m_block_cost = (size_t) n * 3;
const size_t n_block_cost = (size_t) m * 2;
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE,
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn,
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
m_block_cost, n_block_cost, &M_BLOCK_SIZE,
&N_BLOCK_SIZE, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
return -1;
@ -1084,7 +1114,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
if (m >= 128) {
size_t mc = 0, nc = 0, used = 0;
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n,
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn,
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
/*m_block_cost=*/(size_t) n * 3,
/*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 &&
hmx_ceil_div((size_t) n, nc) >= 2) {
@ -1096,7 +1127,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
}
if (!use_pipeline) {
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n,
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn,
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
/*m_block_cost=*/(size_t) n * 3,
/*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
@ -1432,7 +1464,8 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
/*per_n=*/3 * vec_dot_size,
/*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
/*per_mn=*/sizeof(__fp16), params->m, params->n,
/*per_mn=*/sizeof(__fp16),
hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n,
/*m_block_cost=*/(size_t) params->n,
/*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
@ -1612,7 +1645,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
/*per_n=*/3 * vec_dot_size, // W + S0 + S1
/*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
/*per_mn=*/sizeof(__fp16), // O
m, n,
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
/*m_block_cost=*/(size_t) n,
/*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);

View file

@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) {
return op_matmul_hvx(octx);
}
// M alignment: when M > 32 but not 32-aligned, we split into
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
// M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows)
// is handled by HMX itself; when M < 32 fall back to HVX.
const int m_total = (int) src1->ne[1];
const int m_tail = m_total % 32;
const int m_hmx = m_total - m_tail;
const int m_hmx = m_total & ~31; // 0 when M < 32
if (m_hmx == 0) {
return op_matmul_hvx(octx);
@ -3009,7 +3007,6 @@ int op_matmul(struct htp_ops_context * octx) {
int k = (int) src0->ne[0]; // inner dimension
int n = (int) src0->ne[1]; // weight columns
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
int ret = -1;
// Row strides in elements. For compact tensors these equal k; for
@ -3027,7 +3024,7 @@ int op_matmul(struct htp_ops_context * octx) {
.dst = (float *) dst->data,
.activation = (float *) src1->data,
.permuted_weight = (const __fp16 *) src0->data,
.m = m_hmx,
.m = m_total,
.k = k,
.n = n,
.act_stride = act_stride,
@ -3048,12 +3045,12 @@ int op_matmul(struct htp_ops_context * octx) {
} else {
ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
m_hmx, k, n, act_stride, wgt_stride);
m_total, k, n, act_stride, wgt_stride);
}
} else {
ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
(float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
m_hmx, k, n, (int) src0->type);
m_total, k, n, (int) src0->type);
}
if (ret != 0) {
@ -3061,27 +3058,6 @@ int op_matmul(struct htp_ops_context * octx) {
return op_matmul(octx);
}
// --- Phase 2: HVX on the remaining m_tail rows ---
if (m_tail > 0) {
// copy of src1 and dst
struct htp_tensor src1_tail = *src1;
struct htp_tensor dst_tail = *dst;
src1_tail.ne[1] = m_tail; // only tail rows
dst_tail.ne[1] = m_tail; // only tail rows
// Offset activation and dst pointers past the HMX-processed rows.
// Use nb[1] (row stride in bytes) to compute the byte offset.
src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
dst_tail.data += (uint32_t) m_hmx * dst->nb[1];
octx->src[1] = &src1_tail;
octx->dst = &dst_tail;
FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
return op_matmul_hvx(octx);
}
return 0;
#endif // HTP_HAS_HMX
}