mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-17 04:09:19 +00:00
Hexagon: Process M-tail rows on HMX instead of HVX (#22724)
* hex-mm: process m-tail rows on HMX instead of HVX * hmx-mm: unroll and optimize padded activation loop --------- Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
parent
ff806a110d
commit
bbeb89d76c
2 changed files with 48 additions and 39 deletions
|
|
@ -742,17 +742,45 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
|
|||
// activations : fp32 -> fp16
|
||||
|
||||
static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) {
|
||||
for (int r = 0; r < n_rows; r += 2) {
|
||||
const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS);
|
||||
const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS;
|
||||
|
||||
int r = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (r = 0; r < n_rows_tiled; r += 2) {
|
||||
int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
|
||||
int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx
|
||||
|
||||
const bool next_row_valid = (r + 1) < n_rows;
|
||||
|
||||
const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride);
|
||||
const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride);
|
||||
for (int c = 0; c < k_block; c += 32) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero();
|
||||
HVX_Vector v1 = *pv_in1++;
|
||||
|
||||
HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);
|
||||
|
||||
// compute output position
|
||||
int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index
|
||||
int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0;
|
||||
|
||||
HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS);
|
||||
tile[r1 / 2] = v_out;
|
||||
}
|
||||
}
|
||||
|
||||
for (; r < n_rows_padded; r += 2) {
|
||||
int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
|
||||
int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx
|
||||
|
||||
const bool row0_valid = r < n_rows;
|
||||
const bool row1_valid = (r + 1) < n_rows;
|
||||
|
||||
const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL;
|
||||
const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL;
|
||||
for (int c = 0; c < k_block; c += 32) {
|
||||
HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero();
|
||||
HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero();
|
||||
|
||||
HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);
|
||||
|
||||
|
|
@ -889,7 +917,9 @@ static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct h
|
|||
// n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
|
||||
const size_t m_block_cost = (size_t) n * 3;
|
||||
const size_t n_block_cost = (size_t) m * 2;
|
||||
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE,
|
||||
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
m_block_cost, n_block_cost, &M_BLOCK_SIZE,
|
||||
&N_BLOCK_SIZE, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
|
||||
return -1;
|
||||
|
|
@ -1084,7 +1114,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
|||
|
||||
if (m >= 128) {
|
||||
size_t mc = 0, nc = 0, used = 0;
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n,
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
/*m_block_cost=*/(size_t) n * 3,
|
||||
/*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 &&
|
||||
hmx_ceil_div((size_t) n, nc) >= 2) {
|
||||
|
|
@ -1096,7 +1127,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
|||
}
|
||||
|
||||
if (!use_pipeline) {
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n,
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
/*m_block_cost=*/(size_t) n * 3,
|
||||
/*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
|
||||
|
|
@ -1432,7 +1464,8 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
|||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
|
||||
/*per_n=*/3 * vec_dot_size,
|
||||
/*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
|
||||
/*per_mn=*/sizeof(__fp16), params->m, params->n,
|
||||
/*per_mn=*/sizeof(__fp16),
|
||||
hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n,
|
||||
/*m_block_cost=*/(size_t) params->n,
|
||||
/*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
|
||||
|
|
@ -1612,7 +1645,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
|
|||
/*per_n=*/3 * vec_dot_size, // W + S0 + S1
|
||||
/*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
|
||||
/*per_mn=*/sizeof(__fp16), // O
|
||||
m, n,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
/*m_block_cost=*/(size_t) n,
|
||||
/*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
|
||||
|
|
|
|||
|
|
@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// M alignment: when M > 32 but not 32-aligned, we split into
|
||||
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
|
||||
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
|
||||
// M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows)
|
||||
// is handled by HMX itself; when M < 32 fall back to HVX.
|
||||
const int m_total = (int) src1->ne[1];
|
||||
const int m_tail = m_total % 32;
|
||||
const int m_hmx = m_total - m_tail;
|
||||
const int m_hmx = m_total & ~31; // 0 when M < 32
|
||||
|
||||
if (m_hmx == 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
|
|
@ -3009,7 +3007,6 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
int k = (int) src0->ne[0]; // inner dimension
|
||||
int n = (int) src0->ne[1]; // weight columns
|
||||
|
||||
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
|
||||
int ret = -1;
|
||||
|
||||
// Row strides in elements. For compact tensors these equal k; for
|
||||
|
|
@ -3027,7 +3024,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
.dst = (float *) dst->data,
|
||||
.activation = (float *) src1->data,
|
||||
.permuted_weight = (const __fp16 *) src0->data,
|
||||
.m = m_hmx,
|
||||
.m = m_total,
|
||||
.k = k,
|
||||
.n = n,
|
||||
.act_stride = act_stride,
|
||||
|
|
@ -3048,12 +3045,12 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
} else {
|
||||
ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
|
||||
m_hmx, k, n, act_stride, wgt_stride);
|
||||
m_total, k, n, act_stride, wgt_stride);
|
||||
}
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
||||
m_hmx, k, n, (int) src0->type);
|
||||
m_total, k, n, (int) src0->type);
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
|
|
@ -3061,27 +3058,6 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
return op_matmul(octx);
|
||||
}
|
||||
|
||||
// --- Phase 2: HVX on the remaining m_tail rows ---
|
||||
if (m_tail > 0) {
|
||||
// copy of src1 and dst
|
||||
struct htp_tensor src1_tail = *src1;
|
||||
struct htp_tensor dst_tail = *dst;
|
||||
|
||||
src1_tail.ne[1] = m_tail; // only tail rows
|
||||
dst_tail.ne[1] = m_tail; // only tail rows
|
||||
|
||||
// Offset activation and dst pointers past the HMX-processed rows.
|
||||
// Use nb[1] (row stride in bytes) to compute the byte offset.
|
||||
src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
|
||||
dst_tail.data += (uint32_t) m_hmx * dst->nb[1];
|
||||
|
||||
octx->src[1] = &src1_tail;
|
||||
octx->dst = &dst_tail;
|
||||
|
||||
FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
return 0;
|
||||
#endif // HTP_HAS_HMX
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue