hexagon: optimize HMX matmul operations (#21071)

* optimize hmx_mat_mul functions by calculating row and column tiles upfront

* refactor core_dot_chunk_fp16 to use size_t for tile counts and improve readability

* wip

* set scale outside of loop

* wip

* refactor core_mma_chunk_fp16 and mat_mul_qk_0_d16a32 to use size_t for tile counts

* wip

* wip

* refactor transfer_output_chunk_fp16_to_fp32 to use size_t for dimensions

* refactor core_dot_chunk_fp16 to use size_t for tile row stride calculation

* wip

* refactor hmx_mat_mul functions to use hvx_vec_splat_f16 for column scales initialization

* refactor hmx_mat_mul_permuted_w16a32_batched to streamline scale setting and locking

* refactor core_dot_chunk_fp16 to improve tile stride calculations for output

* refactor hmx_mat_mul functions to use Q6_V_vsplat_R for column scales initialization

* fix compiling error

* wip

* optimize row and column tile indexing in core_mma_chunk_fp16 function

* wip

* Revert "wip"

This reverts commit cde679eff79c4a28dd2d89d32f710015e09592b6.

* Add size limit check for HAP_mmap in htp_iface_mmap and drop_mmap functions

* wip
This commit is contained in:
nullname 2026-04-17 04:48:34 +08:00 committed by GitHub
parent 4fbdabdc61
commit 85dde8dc4a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 80 additions and 49 deletions

View file

@ -648,9 +648,9 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
assert(k_block % HMX_FP16_TILE_N_COLS == 0);
int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
int n_tot_tiles = n_col_tiles * n_k_tiles;
size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
size_t n_tot_tiles = n_col_tiles * n_k_tiles;
size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads);
@ -678,9 +678,8 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict
__builtin_assume(n_dot_tiles > 0);
Q6_bias_mxmem2_A((void *)scales);
for (int r = 0; r < n_row_tiles; ++r) {
for (int c = 0; c < n_col_tiles; ++c) {
for (size_t c = 0; c < n_col_tiles; ++c) {
Q6_mxclracc_hf();
const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
@ -738,25 +737,25 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) {
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS;
const HVX_Vector one = hvx_vec_splat_f16(1.0);
for (int r = 0; r < n_rows; r += 2) {
int r0 = r / HMX_FP16_TILE_N_ROWS;
int r1 = r % HMX_FP16_TILE_N_ROWS;
for (size_t r = 0; r < n_rows; r += 2) {
const size_t r0 = r / HMX_FP16_TILE_N_ROWS;
const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile
const __fp16 *row_base = vtcm_src + r0 * tile_row_stride;
float *output_row_base = dst + r * n; // global memory row base for row r (and r+1)
#pragma unroll(4)
for (int c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) {
int c0 = c / HMX_FP16_TILE_N_COLS;
const __fp16 *tile = vtcm_src + (r0 * n_col_tiles + c0) * HMX_FP16_TILE_N_ELMS;
HVX_Vector v = ((const HVX_Vector *) tile)[r1 / 2];
for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) {
const size_t c0 = c / HMX_FP16_TILE_N_COLS;
const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS;
HVX_Vector v = ((const HVX_Vector *) tile)[r1];
HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);
volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (dst + (r * n + c + 0));
volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (dst + (r * n + c + n)); // next row in global memory
volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0);
volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n); // next row in global memory
*pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
if (r + 1 < n_rows) {
@ -794,7 +793,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
size_t n_tot_chunks = n_rows;
size_t n_chunks_per_task = 32; // must be multiple of HMX_FP16_TILE_N_ROWS (32)
size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32)
output_transfer_task_state_t state;
state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task;
@ -926,7 +925,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
}
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu",
__func__, params->m, params->k, params->n, group_size, params->ne13,
@ -944,12 +943,15 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16);
const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16);
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
for (int b3 = 0; b3 < params->ne13; ++b3) {
for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) {
const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3);
for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) {
const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows);
const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS);
// Pre-load activations for all heads in the group (once per m_chunk).
// When the source is strided (permuted Q), use 2D DMA to gather
@ -987,10 +989,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
}
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) {
const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols);
const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
TIMER_START(weight_load);
{
@ -1014,11 +1015,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
for (int g = 0; g < group_size; ++g) {
TIMER_START(hmx_core);
{
const __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
const int n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS);
const int n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales,
n_row_tiles, n_col_tiles, params->k / 32);
const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles,
params->k / 32);
}
TIMER_STOP(hmx_core);
@ -1030,12 +1029,12 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
TIMER_STOP(output_store);
}
}
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
}
}
}
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
@ -1103,7 +1102,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
return -1;
}
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu",
__func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols,
@ -1121,7 +1120,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
// transfer activation matrix chunk into VTCM
size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
TIMER_START(activation_load);
{
@ -1159,7 +1159,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
}
for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
TIMER_START(weight_load);
{
@ -1184,8 +1185,6 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
TIMER_START(hmx_core);
{
const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
}
TIMER_STOP(hmx_core);
@ -1307,7 +1306,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
return -1;
}
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu",
__func__, m, k, n, weight_type, use_pipeline,
@ -1330,7 +1329,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
// transfer activation matrix chunk into VTCM
size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
TIMER_START(activation_load);
{
@ -1348,7 +1348,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
}
for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
TIMER_START(weight_load);
{
@ -1373,8 +1374,6 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
TIMER_START(hmx_core);
{
const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
}
TIMER_STOP(hmx_core);
@ -1521,14 +1520,16 @@ void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __f
Q6_bias_mxmem2_A((void *)col_scales);
for (int i = 0; i < n_row_tiles; ++i) {
for (int j = 0; j < n_col_tiles; ++j) {
const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS;
for (size_t i = 0; i < n_row_tiles; ++i) {
const __fp16 *row_base = a + i * dot_tile_stride;
__fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS;
for (size_t j = 0; j < n_col_tiles; ++j) {
Q6_mxclracc_hf();
const __fp16 *row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
const __fp16 *col_tiles = b + j * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
__fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS;
const __fp16 *col_tiles = b + j * dot_tile_stride;
const __fp16 *row_tiles = row_base;
__fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS;
if (!zero_init) {
Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047);
Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047);
@ -1697,7 +1698,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
v = Q6_V_vror_VR(v, VLEN - 8);
}
}
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
TIMER_DEFINE(fetch);
TIMER_DEFINE(act_load);
@ -1715,7 +1716,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS);
for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) {
size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE);
const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE);
TIMER_START(fetch);
// fetch activation block into VTCM
@ -1731,13 +1732,13 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
}
// fetch weight block into VTCM (x4x2 sub-block: quants + scales)
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
{
qweight_fetch_task_state_t s;
const int blk_start = kk / QK_Q4_0x4x2;
const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
const int scale_blk_size =
(weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
@ -1777,7 +1778,6 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
dma_queue_pop(ctx->dma[0]);
// vtcm_scratch0 is used to store the qweight chunk
// worker_pool_run_func already returned, so fetch is done
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0,
n_blk_sz, k_blk_sz, sub_row_stride, weight_type);
}

View file

@ -98,6 +98,8 @@ enum htp_op_code {
#define HTP_OP_MAX_VMEM (3221225472u)
#endif
#define HTP_MMAP_MAX_VMEM (2147483648u)
enum htp_tensor_flags {
HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights)
HTP_TENSOR_FLUSHED = (1U << 1) // Tensor buffer has been flushed (set by the NPU)

View file

@ -118,7 +118,11 @@ AEEResult htp_iface_close(remote_handle64 handle) {
// release the mmaps (if any)
for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) {
if (ctx->mmap[i].size) {
#if __HVX_ARCH__ > 73
HAP_munmap2((void *) ctx->mmap[i].base, ctx->mmap[i].size);
#else
HAP_munmap((void *) ctx->mmap[i].base, ctx->mmap[i].size);
#endif
ctx->mmap[i].size = 0;
ctx->mmap[i].base = NULL;
ctx->mmap[i].fd = -1;
@ -173,8 +177,16 @@ AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t
struct htp_mmap *m = &ctx->mmap[i];
if (!m->size) {
FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned);
#if __HVX_ARCH__ > 73
void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0);
#else
if (size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB
FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) size);
abort(); // can't do much else at this point
}
void *va = HAP_mmap(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0);
#endif
if (va == (void*)-1) {
FARF(ERROR, "mmap failed : va %p fd %u size %u", va, fd, (uint32_t) size);
return AEE_EFAILED;
@ -202,7 +214,11 @@ AEEResult htp_iface_munmap(remote_handle64 handle, int fd) {
struct htp_mmap *m = &ctx->mmap[i];
if (fd < 0 || m->fd == fd) {
FARF(HIGH, "unmmap : base %p fd %u size %u", (void*) m->base, m->fd, (uint32_t) m->size);
#if __HVX_ARCH__ > 73
HAP_munmap2((void *) m->base, m->size);
#else
HAP_munmap((void *) m->base, m->size);
#endif
m->size = 0;
m->base = NULL;
m->fd = -1;
@ -526,7 +542,11 @@ static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct
static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) {
if (m->size && !m->pinned) {
FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned);
#if __HVX_ARCH__ > 73
HAP_munmap2((void *) m->base, m->size);
#else
HAP_munmap((void *) m->base, m->size);
#endif
m->size = 0;
m->base = 0;
m->fd = -1;
@ -540,7 +560,16 @@ static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) {
for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) {
struct htp_mmap *m = &ctx->mmap[i];
if (!m->size) {
#if __HVX_ARCH__ > 73
void *va = HAP_mmap2(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0);
#else
if (b->size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB
FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) b->size);
abort(); // can't do much else at this point
}
void *va = HAP_mmap(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0);
#endif
if (va == (void*)-1) {
FARF(ERROR, "mmap failed : va %p fd %u size %u", va, b->fd, (uint32_t) b->size);
abort(); // can't do much else at this point