@ -201,11 +201,10 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32
// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using
// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls.
// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes.
static inline void dequantize_x4x2_q4_0_x4groups_hvx (
// Output: each of the two vectors in the returned HVX_Vector_x2 holds 32 FP16 values in its first 64 bytes.
static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx (
const uint8_t * packed_128 , bool upper_nibbles ,
const __fp16 * scales_4 , const HVX_Vector vlut_cvt ,
HVX_Vector out [ 4 ] ) {
const __fp16 * scales_4 , const HVX_Vector vlut_cvt ) {
// Load all 128 packed bytes (4 contiguous 32-byte groups)
HVX_Vector vq = hvx_vmemu ( packed_128 ) ;
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R ( 0x0F ) ;
@ -221,8 +220,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
HVX_Vector v_hi = Q6_V_hi_W ( vp ) ; // [group2: 32 fp16 | group3: 32 fp16]
// Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
volatile HVX_Vector vscale = hvx_vmemu ( scales_4 ) ;
HVX_Vector vscale = hvx_vmemu ( scales_4 ) ;
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16 ( vscale ) ;
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16 ( Q6_V_vror_VR ( vscale , 4 ) ) ;
@ -230,8 +228,9 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
v_hi = Q6_Vhf_equals_Vqf16 ( Q6_Vqf16_vmpy_VhfVhf ( v_hi , v_sc23 ) ) ;
// Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
out [ 0 ] = v_lo ; // group0 already in [0:63]
out [ 1 ] = v_hi ; // group2 already in [0:63]
HVX_Vector_x2 r = { v_lo , /* group0 already in [0:63] */
v_hi /* group2 already in [0:63] */ } ;
return r ;
}
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
@ -292,12 +291,11 @@ static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed
}
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
static inline void dequantize_x4x2_mxfp4_x4groups_hvx ( const uint8_t * packed_128 ,
static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx ( const uint8_t * packed_128 ,
bool upper_nibbles ,
int sub_blk_base ,
const HVX_Vector vlut_cvt ,
mxfp4_scales_t scales ,
HVX_Vector out [ 4 ] ) {
mxfp4_scales_t scales ) {
HVX_Vector vq = hvx_vmemu ( packed_128 ) ;
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R ( 0x0F ) ;
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR ( vq , 4 ) : vq ;
@ -318,10 +316,8 @@ static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_12
v_lo = Q6_Vhf_equals_Vqf16 ( Q6_Vqf16_vmpy_VhfVhf ( v_lo , v_sc01 ) ) ;
v_hi = Q6_Vhf_equals_Vqf16 ( Q6_Vqf16_vmpy_VhfVhf ( v_hi , v_sc23 ) ) ;
out [ 0 ] = v_lo ;
out [ 1 ] = Q6_V_vror_VR ( v_lo , 64 ) ;
out [ 2 ] = v_hi ;
out [ 3 ] = Q6_V_vror_VR ( v_hi , 64 ) ;
HVX_Vector_x4 r = { v_lo , Q6_V_vror_VR ( v_lo , 64 ) , v_hi , Q6_V_vror_VR ( v_hi , 64 ) } ;
return r ;
}
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
@ -372,18 +368,18 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1 ;
for ( int r = 0 ; r < HMX_FP16_TILE_N_ROWS ; r + = 2 , row1 + = 2 ) {
HVX_Vector v0 [ 2 ] ;
const uint8_t * r0 = vtcm_src + row_offset ; row_offset + = row_stride ;
dequantize_x4x2_q4_0_x4groups_hvx ( r0 + packed_off , upper , ( const __fp16 * ) ( r0 + scale_off ) , vlut_cvt , v0 ) ;
Q6_vscatter_RMVwV ( ( size_t ) tile_bases [ 0 ] , 2 * HMX_FP16_TILE_SIZE - 1 , v_off , v0 [ 0 ] ) ;
Q6_vscatter_RMVwV ( ( size_t ) tile_bases [ 2 ] , 2 * HMX_FP16_TILE_SIZE - 1 , v_off , v0 [ 1 ] ) ;
const uint8_t * r1 = vtcm_src + row_offset ; row_offset + = row_stride ;
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx ( r0 + packed_off , upper , ( const __fp16 * ) ( r0 + scale_off ) , vlut_cvt ) ;
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx ( r1 + packed_off , upper , ( const __fp16 * ) ( r1 + scale_off ) , vlut_cvt ) ;
Q6_vscatter_RMVwV ( ( size_t ) tile_bases [ 0 ] , 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv0 . v [ 0 ] ) ;
Q6_vscatter_RMVwV ( ( size_t ) tile_bases [ 2 ] , 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv0 . v [ 1 ] ) ;
v_off = Q6_Vw_vadd_VwVw ( v_off , v_scat_step ) ;
r0 = vtcm_src + row_offset ; row_offset + = row_stride ;
dequantize_x4x2_q4_0_x4groups_hvx ( r0 + packed_off , upper , ( const __fp16 * ) ( r0 + scale_off ) , vlut_cvt , v0 ) ;
Q6_vscatter_RMVwV ( ( size_t ) tile_bases [ 0 ] , 2 * HMX_FP16_TILE_SIZE - 1 , v_off , v0 [ 0 ] ) ;
Q6_vscatter_RMVwV ( ( size_t ) tile_bases [ 2 ] , 2 * HMX_FP16_TILE_SIZE - 1 , v_off , v0 [ 1 ] ) ;
Q6_vscatter_RMVwV ( ( size_t ) tile_bases [ 0 ] , 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv1 . v [ 0 ] ) ;
Q6_vscatter_RMVwV ( ( size_t ) tile_bases [ 2 ] , 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv1 . v [ 1 ] ) ;
v_off = Q6_Vw_vadd_VwVw ( v_off , v_scat_step ) ;
}
@ -415,21 +411,21 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
mxfp4_scales_t r0_e8 = mxfp4_convert_scales ( r0 + e8m0_blk_off ) ;
HVX_Vector v0[4 ] , v1[4 ] ;
dequantize_x4x2_mxfp4_x4groups_hvx ( r0 + packed_off , upper , sub_blk_base , vlut_cvt , r0_e8 , v0 );
HVX_Vector _x4 d v0 , d v1 ;
dv0 = dequantize_x4x2_mxfp4_x4groups_hvx ( r0 + packed_off , upper , sub_blk_base , vlut_cvt , r0_e8 );
if ( row1 < n_cols ) {
mxfp4_scales_t r1_e8 = mxfp4_convert_scales ( r1 + e8m0_blk_off ) ;
dequantize_x4x2_mxfp4_x4groups_hvx ( r1 + packed_off , upper , sub_blk_base , vlut_cvt , r1_e8 , v1 );
dv1 = dequantize_x4x2_mxfp4_x4groups_hvx ( r1 + packed_off , upper , sub_blk_base , vlut_cvt , r1_e8 );
} else {
v1 [0 ] = v1 [1 ] = v1 [2 ] = v1 [3 ] = Q6_V_vzero ( ) ;
d v1.v [0 ] = d v1.v [1 ] = d v1.v [2 ] = d v1.v [3 ] = Q6_V_vzero ( ) ;
}
for ( int g = 0 ; g < 4 ; g + + ) {
Q6_vscatter_QRMVwV ( q_mask64 , ( size_t ) tile_bases [ g ] , HMX_FP16_TILE_SIZE - 1 , v_off , v0 [g ] ) ;
Q6_vscatter_QRMVwV ( q_mask64 , ( size_t ) tile_bases [ g ] , HMX_FP16_TILE_SIZE - 1 , v_off , d v0.v [g ] ) ;
}
v_off = Q6_Vw_vadd_VwVw ( v_off , v_scat_step ) ;
for ( int g = 0 ; g < 4 ; g + + ) {
Q6_vscatter_QRMVwV ( q_mask64 , ( size_t ) tile_bases [ g ] , HMX_FP16_TILE_SIZE - 1 , v_off , v1 [g ] ) ;
Q6_vscatter_QRMVwV ( q_mask64 , ( size_t ) tile_bases [ g ] , HMX_FP16_TILE_SIZE - 1 , v_off , d v1.v [g ] ) ;
}
v_off = Q6_Vw_vadd_VwVw ( v_off , v_scat_step ) ;
}
@ -612,11 +608,13 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict
const __fp16 * row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS ;
const __fp16 * col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS ;
for ( int k = 0 ; k < n_dot_tiles ; + + k ) {
Q6_activation_hf_mxmem_RR ( ( unsigned int ) row_tiles , 2047 ) ;
Q6_weight_hf_mxmem_RR ( ( unsigned int ) col_tiles , 2047 ) ;
row_tiles + = HMX_FP16_TILE_N_ELMS ;
col_tiles + = HMX_FP16_TILE_N_ELMS ;
for ( int k = 0 , k_block ; k < n_dot_tiles ; k + = k_block ) {
k_block = hex_smin ( n_dot_tiles - k , 32 ) ;
const uint32_t range = 2048u * ( uint32_t ) k_block - 1 ;
Q6_activation_hf_mxmem_RR_deep ( ( unsigned int ) row_tiles , range ) ;
Q6_weight_hf_mxmem_RR ( ( unsigned int ) col_tiles , range ) ;
row_tiles + = k_block * HMX_FP16_TILE_N_ELMS ;
col_tiles + = k_block * HMX_FP16_TILE_N_ELMS ;
}
__fp16 * out_tile = output + ( r * n_col_tiles + c ) * HMX_FP16_TILE_N_ELMS ;
@ -832,10 +830,6 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *
worker_pool_run_func ( ctx - > worker_pool , transfer_activation_chunk_worker_fn , & state , ctx - > n_threads ) ;
}
//
# define FALLBACK_TO_STANDARD 1
// C += AB
static void core_mma_chunk_fp16 ( __fp16 * restrict c , const __fp16 * restrict a , const __fp16 * restrict b ,
const __fp16 * restrict col_scales , const __fp16 * restrict eye_tile ,
@ -861,314 +855,80 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co
Q6_weight_hf_mxmem_RR ( ( unsigned int ) eye_tile , 2047 ) ;
}
for ( int k = 0 ; k < n_dot_tiles ; + + k ) {
Q6_activation_hf_mxmem_RR ( ( unsigned int ) row_tiles , 2047 ) ;
Q6_weight_hf_mxmem_RR ( ( unsigned int ) col_tiles , 2047 ) ;
row_tiles + = HMX_FP16_TILE_N_ELMS ;
col_tiles + = HMX_FP16_TILE_N_ELMS ;
for ( int k = 0 , k_block ; k < n_dot_tiles ; k + = k_block ) {
k_block = hex_smin ( n_dot_tiles - k , 32 ) ;
const uint32_t range = 2048u * ( uint32_t ) k_block - 1 ;
Q6_activation_hf_mxmem_RR_deep ( ( unsigned int ) row_tiles , range ) ;
Q6_weight_hf_mxmem_RR ( ( unsigned int ) col_tiles , range ) ;
row_tiles + = k_block * HMX_FP16_TILE_N_ELMS ;
col_tiles + = k_block * HMX_FP16_TILE_N_ELMS ;
}
Q6_mxmem_AR_after_hf ( accum_tile , 0 ) ;
}
}
}
// Out-stationary variant of the quantized matmul: walks M blocks, then N blocks, then
// K blocks of K_BLOCK_SIZE, and stores an output block only after its whole K loop.
//
// ctx          supplies the VTCM region (vtcm_base / vtcm_size), the DMA queue dma[0]
//              and the resource context (vtcm_rctx) used for the HMX lock.
// out          F32 result, addressed as out + mr * n + nc (row stride n).
// x            F32 activation, addressed as x + mr * k + kk (row stride k).
// w            x4x2-packed weight, addressed as w + nc * row_stride
//              (one packed row per output column).
// m, k, n      problem dimensions; k and n are assumed to be multiples of 32 (not checked here).
// weight_type  compared against HTP_TYPE_Q8_0 and HTP_TYPE_MXFP4 to pick quant/scale layouts.
//
// Returns 0 on success, -1 if the row stride is 0 or the VTCM layout does not fit,
// and FALLBACK_TO_STANDARD when k fits in a single K block (k <= K_BLOCK_SIZE).
static __attribute__ ( ( noinline ) ) int mat_mul_qk_0_d16a32_out_stationary ( struct htp_context * ctx ,
float * restrict out , const float * restrict x , const uint8_t * restrict w ,
int m , int k , int n , int weight_type ) {
// assume k % 32 == 0 && n % 32 == 0
const size_t row_stride = get_x4x2_row_stride ( weight_type , k ) ;
if ( row_stride = = 0 ) {
return - 1 ;
}
const size_t vtcm_budget = ctx - > vtcm_size ;
const size_t K_BLOCK_SIZE = 1024 ;
// Fallback: if k doesn't need K-blocking, out-stationary has no advantage
const size_t k_iters_check = ( k + K_BLOCK_SIZE - 1 ) / K_BLOCK_SIZE ;
if ( k_iters_check < = 1 ) {
FARF ( HIGH , " %s: K_BLK=%zu >= k=%d, fallback to standard path " , __func__ , K_BLOCK_SIZE , k ) ;
return FALLBACK_TO_STANDARD ;
}
// Dynamic M,N search via hmx_compute_chunks
// per_m / per_n / per_mn are the VTCM bytes charged per M row, per N column and per
// output element; they mirror the buffers allocated further down.
const size_t sub_row_stride_alloc = get_x4x2_row_stride ( weight_type , K_BLOCK_SIZE ) ;
const size_t per_m = K_BLOCK_SIZE * sizeof ( float ) // scratch1: M× K× 4 (act DMA staging F32)
+ K_BLOCK_SIZE * sizeof ( __fp16 ) ; // activation: M× K× 2 (F16 tiles)
const size_t per_n = sub_row_stride_alloc // scratch0: N× sub_row(K) (packed quant)
+ K_BLOCK_SIZE * sizeof ( __fp16 ) ; // weight: N× K× 2 (F16 tiles)
const size_t per_mn = sizeof ( __fp16 ) ; // output: M× N× 2 (out-stationary)
// Alignment margin: hex_align_up can add up to 2047 bytes per buffer;
// scratch1 (mc× 6144) is naturally 2048-aligned, remaining 4 buffers need margin
const size_t align_margin = 4 * HMX_FP16_TILE_SIZE ;
const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin ; // eye_tile + scales + alignment
size_t M_BLOCK_SIZE , N_BLOCK_SIZE , vtcm_used ;
// Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost.
// From profiling: wt_dequant per element ≈ 1.5× activation load per element.
// m_block_cost = n*3: each extra M-block re-dequants all N× K weight (expensive).
// n_block_cost = m*2: each extra N-block re-loads all M× K activation (cheaper).
const size_t m_block_cost = ( size_t ) n * 3 ;
const size_t n_block_cost = ( size_t ) m * 2 ;
if ( hmx_compute_chunks ( vtcm_budget , overhead , per_n , per_m , per_mn ,
hex_align_up ( m , HMX_FP16_TILE_N_ROWS ) , n ,
m_block_cost , n_block_cost , & M_BLOCK_SIZE ,
& N_BLOCK_SIZE , & vtcm_used ) ! = 0 ) {
FARF ( HIGH , " %s: VTCM too small (m=%d k=%d n=%d budget=%zu) " , __func__ , m , k , n , vtcm_budget ) ;
return - 1 ;
}
// Compute precise buffer sizes from searched M,N and fixed K
const size_t weight_size = hex_align_up ( N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof ( __fp16 ) , HMX_FP16_TILE_SIZE ) ;
const size_t act_size = hex_align_up ( M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof ( __fp16 ) , HMX_FP16_TILE_SIZE ) ;
const size_t out_size = hex_align_up ( M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof ( __fp16 ) , HMX_FP16_TILE_SIZE ) ;
const size_t scratch0_sz = hex_align_up ( N_BLOCK_SIZE * sub_row_stride_alloc , HMX_FP16_TILE_SIZE ) ;
const size_t scratch1_sz = hex_align_up ( M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof ( float ) , HMX_FP16_TILE_SIZE ) ;
// Re-check the aligned total: the search above only budgets an alignment margin.
const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256 ;
if ( total_vtcm > vtcm_budget ) {
FARF ( HIGH , " %s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu) " , __func__ , total_vtcm ,
vtcm_budget , M_BLOCK_SIZE , N_BLOCK_SIZE , K_BLOCK_SIZE ) ;
return - 1 ;
}
// Carve the buffers out of VTCM sequentially, starting at vtcm_base.
uint8_t * vtcm_ptr = ( uint8_t * ) ctx - > vtcm_base ;
__fp16 * vtcm_weight = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , weight_size ) ;
__fp16 * vtcm_activation = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , act_size ) ;
__fp16 * vtcm_output = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , out_size ) ;
uint8_t * vtcm_scratch0 = vtcm_seq_alloc ( & vtcm_ptr , scratch0_sz ) ;
uint8_t * vtcm_scratch1 = vtcm_seq_alloc ( & vtcm_ptr , scratch1_sz ) ;
__fp16 * vtcm_eye_tile = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , HMX_FP16_TILE_SIZE ) ;
__fp16 * vtcm_scales = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , 256 ) ;
assert ( ( size_t ) ( vtcm_ptr - ( uint8_t * ) ctx - > vtcm_base ) < = vtcm_budget ) ;
FARF ( HIGH , " hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu " , m , k , n , weight_type ,
M_BLOCK_SIZE , N_BLOCK_SIZE , K_BLOCK_SIZE , ( size_t ) ( vtcm_ptr - ( uint8_t * ) ctx - > vtcm_base ) , vtcm_budget ) ;
// initialize eye tile (32x32 identity matrix)
{
HVX_Vector v ;
v = Q6_V_vzero ( ) ;
v = Q6_Vw_vinsert_VwR ( v , 0x3c000000 ) ;
v = Q6_V_vror_VR ( v , VLEN - 4 ) ;
v = Q6_Vw_vinsert_VwR ( v , 0x00003c00 ) ;
for ( int i = 0 ; i < 16 ; + + i ) {
( ( HVX_Vector * ) vtcm_eye_tile ) [ i ] = v ;
v = Q6_V_vror_VR ( v , VLEN - 8 ) ;
}
}
hmx_init_column_scales ( vtcm_scales , Q6_V_vsplat_R ( 0x3c00 ) ) ; // scale: 1.0, bias: 0.0 in FP16
TIMER_DEFINE ( fetch ) ;
TIMER_DEFINE ( act_load ) ;
TIMER_DEFINE ( wt_dequant ) ;
TIMER_DEFINE ( core ) ;
// Every early return is above this point, so the lock taken here is always
// paired with the unlock after the loops.
HAP_compute_res_hmx_lock ( ctx - > vtcm_rctx ) ;
for ( size_t mr = 0 ; mr < m ; mr + = M_BLOCK_SIZE ) {
size_t m_blk_sz = hex_smin ( m - mr , M_BLOCK_SIZE ) ;
for ( size_t nc = 0 ; nc < n ; nc + = N_BLOCK_SIZE ) {
size_t n_blk_sz = hex_smin ( n - nc , N_BLOCK_SIZE ) ;
const int n_row_tiles = hmx_ceil_div ( m_blk_sz , HMX_FP16_TILE_N_ROWS ) ;
const int n_col_tiles = hmx_ceil_div ( n_blk_sz , HMX_FP16_TILE_N_COLS ) ;
for ( size_t kk = 0 ; kk < k ; kk + = K_BLOCK_SIZE ) {
const size_t k_blk_sz = hex_smin ( k - kk , K_BLOCK_SIZE ) ;
TIMER_START ( fetch ) ;
// fetch activation block into VTCM
// m_blk_sz rows of k_blk_sz floats, read with a source row stride of k floats.
// NOTE(review): argument order assumed to be (dst_stride, src_stride, width, rows),
// matching the named arguments of the weight pushes below — confirm.
{
const float * activation_block = x + mr * k + kk ;
dma_queue_push ( ctx - > dma [ 0 ] ,
dma_make_ptr ( vtcm_scratch1 , activation_block ) ,
k_blk_sz * sizeof ( float ) ,
k * sizeof ( float ) ,
k_blk_sz * sizeof ( float ) ,
m_blk_sz ) ;
}
// fetch weight block into VTCM (x4x2 sub-block: quants + scales)
const size_t sub_row_stride = get_x4x2_row_stride ( weight_type , k_blk_sz ) ;
{
const int blk_start = kk / QK_Q4_0x4x2 ;
const int nb_sub = ( k_blk_sz + QK_Q4_0x4x2 - 1 ) / QK_Q4_0x4x2 ;
const int full_qrow = ( weight_type = = HTP_TYPE_Q8_0 ) ? k : ( k / 2 ) ;
const int scale_blk_size = ( weight_type = = HTP_TYPE_MXFP4 ) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE ;
uint8_t * dst = vtcm_scratch0 ;
const uint8_t * src = w + nc * row_stride ;
const size_t n_rows = n_blk_sz ;
const size_t src_stride = row_stride ;
const size_t dst_stride = sub_row_stride ;
const size_t quant_off = ( weight_type = = HTP_TYPE_Q8_0 ) ? ( blk_start * QK_Q8_0x4x2 ) : ( blk_start * ( QK_Q4_0x4x2 / 2 ) ) ;
const size_t quant_width = ( weight_type = = HTP_TYPE_Q8_0 ) ? ( nb_sub * QK_Q8_0x4x2 ) : ( nb_sub * ( QK_Q4_0x4x2 / 2 ) ) ;
// In the source row the scales start at byte full_qrow; in the
// destination row they are placed at byte quant_width.
const size_t scale_off = full_qrow + blk_start * scale_blk_size ;
const size_t scale_width = nb_sub * scale_blk_size ;
// 2D DMA: quants sub-range
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( dst , src + quant_off ) , dst_stride , src_stride , quant_width , n_rows ) ;
// 2D DMA: scales sub-range
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( dst + quant_width , src + scale_off ) , dst_stride , src_stride , scale_width , n_rows ) ;
}
TIMER_STOP ( fetch ) ;
TIMER_START ( act_load ) ;
// load activation block
{
// NOTE(review): assumes pops complete in push order, so this one
// pairs with the activation push above — confirm.
dma_queue_pop ( ctx - > dma [ 0 ] ) ; // wait for act DMA
transfer_activation_chunk_threaded ( ctx , vtcm_activation , ( float * ) vtcm_scratch1 , m_blk_sz , k_blk_sz , k_blk_sz ) ;
}
TIMER_STOP ( act_load ) ;
TIMER_START ( wt_dequant ) ;
// dequantize weight block
{
// two pops: one per weight DMA pushed during fetch (quants, then scales)
dma_queue_pop ( ctx - > dma [ 0 ] ) ;
dma_queue_pop ( ctx - > dma [ 0 ] ) ;
// vtcm_scratch0 is used to store the qweight chunk
// worker_pool_run_func already returned, so fetch is done
dequantize_x4x2_weight_chunk_to_fp16_tiles ( ctx , vtcm_weight , vtcm_scratch0 ,
n_blk_sz , k_blk_sz , sub_row_stride , weight_type ) ;
}
TIMER_STOP ( wt_dequant ) ;
// core mma
// NOTE(review): the last argument (kk == 0) is true only for the first K block;
// presumably it makes core_mma_chunk_fp16 start a fresh accumulation —
// confirm against its definition.
TIMER_START ( core ) ;
{
core_mma_chunk_fp16 ( vtcm_output , vtcm_activation , vtcm_weight , vtcm_scales , vtcm_eye_tile , n_row_tiles ,
n_col_tiles , k_blk_sz / HMX_FP16_TILE_N_COLS , kk = = 0 ) ;
}
TIMER_STOP ( core ) ;
}
// store output block (once per (M, N) block, after the whole K loop)
{
float * output_block = out + ( mr * n + nc ) ;
transfer_output_chunk_threaded ( ctx , output_block , vtcm_output , m_blk_sz , n_blk_sz , n ) ;
}
}
}
HAP_compute_res_hmx_unlock ( ctx - > vtcm_rctx ) ;
# if defined(ENABLE_PROFILE_TIMERS)
FARF ( HIGH , " fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us " ,
TIMER_US ( fetch ) , TIMER_US ( act_load ) , TIMER_US ( wt_dequant ) , TIMER_US ( core ) ) ;
# endif
return 0 ;
}
int hmx_mat_mul_permuted_qk_0_d16a32 ( struct htp_context * ctx , float * restrict dst , const float * restrict activation ,
int hmx_matmul_q_f32 ( struct htp_context * ctx , float * restrict dst , const float * restrict activation ,
const uint8_t * restrict permuted_weight , int m , int k , int n ,
int weight_type ) {
if ( ! dst | | ! activation | | ! permuted_weight | | ! m | | ! n | | ! k ) { return - 1 ; }
if ( k % 32 ! = 0 | | n % 32 ! = 0 ) { return - 1 ; }
if ( ! hex_is_aligned ( dst , VLEN ) | | ! hex_is_aligned ( activation , VLEN ) | | ! hex_is_aligned ( permuted_weight , VLEN ) ) {
return - 1 ;
}
// for large m, k (e.g. prefill FFN Down), use out-stationary version
if ( m > = 128 & & k > n & & n > 1024 ) {
int rc = mat_mul_qk_0_d16a32_out_stationary ( ctx , dst , activation , permuted_weight , m , k , n , weight_type ) ;
if ( rc ! = FALLBACK_TO_STANDARD ) {
return rc ; // 0 success, -1 error
}
FARF ( HIGH , " hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d " , m , k , n ) ;
// fall through to standard path
}
size_t row_stride = get_x4x2_row_stride ( weight_type , k ) ;
if ( row_stride = = 0 ) {
return - 1 ;
}
FARF ( HIGH , " hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d " , m , k , n , weight_type ) ;
// --- Dynamic VTCM layout ---
const size_t vtcm_budget = ctx - > vtcm_size ;
const size_t vec_dot_size = k * sizeof ( __fp16 ) ;
const size_t vec_dot_size = k * sizeof ( __fp16 ) ;
const size_t vtcm_budget = ctx - > vtcm_size ;
size_t vtcm_used = 0 ;
// Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap.
// Only pays off when the chunker yields >=2 n-chunks, so the main loop can
// overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for
// double-buffered output and the worker-dispatch overhead are pure loss.
// Try pipeline costs first; fall back to sequential if the layout collapses
// to one n-chunk. m >= 128 floor keeps HMX utilization reasonable.
const size_t pipe_per_n = row_stride + 2 * vec_dot_size ; // Q + S0 + S1 (dequant bufs)
const size_t pipe_per_mn = 2 * sizeof ( __fp16 ) ; // O x 2 (output double buffer)
const size_t seq_per_n = vec_dot_size + 2 * row_stride ; // W + S0 + S1 (x4x2 DMA bufs)
const size_t seq_per_mn = sizeof ( __fp16 ) ; // O x 1
const size_t size_per_n = row_stride + 2 * vec_dot_size ; // Q + S0 + S1 (dequant bufs)
const size_t size_per_mn = 2 * sizeof ( __fp16 ) ; // O x 2 (output double buffer)
size_t m_chunk_n_rows = 0 , n_chunk_n_cols = 0 , vtcm_used = 0 ;
bool use_pipeline = false ;
if ( m > = 128 ) {
size_t mc = 0 , nc = 0 , used = 0 ;
if ( hmx_compute_chunks ( vtcm_budget , /*overhead=*/ 256 , pipe_per_n , /*per_m=*/ vec_dot_size , pipe_per_mn ,
hex_align_up ( m , HMX_FP16_TILE_N_ROWS ) , n ,
/*m_block_cost=*/ ( size_t ) n * 3 ,
/*n_block_cost=*/ ( size_t ) m * 2 , & mc , & nc , & used ) = = 0 & &
hmx_ceil_div ( ( size_t ) n , nc ) > = 2 ) {
m_chunk_n_rows = mc ;
n_chunk_n_cols = nc ;
vtcm_used = used ;
use_pipeline = true ;
}
size_t m_chunk_n_rows = 0 , n_chunk_n_cols = 0 ;
if ( hmx_compute_chunks ( vtcm_budget , /*overhead=*/ 256 , size_per_n , /*per_m=*/ vec_dot_size , size_per_mn ,
hex_align_up ( m , HMX_FP16_TILE_N_ROWS ) , n ,
/*m_block_cost=*/ ( size_t ) n * 3 ,
/*n_block_cost=*/ ( size_t ) m * 2 , & m_chunk_n_rows , & n_chunk_n_cols , & vtcm_used ) ) {
FARF ( HIGH , " hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu " , m , k , n , vtcm_budget ) ;
return - 1 ;
}
if ( ! use_pipeline ) {
if ( hmx_compute_chunks ( vtcm_budget , /*overhead=*/ 256 , seq_per_n , /*per_m=*/ vec_dot_size , seq_per_mn ,
hex_align_up ( m , HMX_FP16_TILE_N_ROWS ) , n ,
/*m_block_cost=*/ ( size_t ) n * 3 ,
/*n_block_cost=*/ ( size_t ) m * 2 , & m_chunk_n_rows , & n_chunk_n_cols , & vtcm_used ) ! = 0 ) {
FARF ( HIGH , " %s: VTCM too small (m=%d k=%d n=%d budget=%zu) " , __func__ , m , k , n , vtcm_budget ) ;
return - 1 ;
}
}
// Compute precise buffer sizes per execution path
const size_t weight_area_size = hex_align_up (
n_chunk_n_cols * ( use_pipeline ? row_stride : vec_dot_size ) , HMX_FP16_TILE_SIZE ) ;
const size_t activation_area_size = hex_align_up ( m_chunk_n_rows * vec_dot_size , HMX_FP16_TILE_SIZE ) ;
const size_t output_area_size = hex_align_up (
m_chunk_n_rows * n_chunk_n_cols * sizeof ( __fp16 ) , HMX_FP16_TILE_SIZE ) ;
const size_t weight_area_size = hex_align_up ( n_chunk_n_cols * row_stride , HMX_FP16_TILE_SIZE ) ;
const size_t act_area_size = hex_align_up ( m_chunk_n_rows * vec_dot_size , HMX_FP16_TILE_SIZE ) ;
const size_t output_area_size = hex_align_up ( m_chunk_n_rows * n_chunk_n_cols * sizeof ( __fp16 ) , HMX_FP16_TILE_SIZE ) ;
size_t scratch0_size , scratch1_size , scratch2_size ;
if ( use_pipeline ) {
scratch0_size = hex_align_up ( n_chunk_n_cols * vec_dot_size , HMX_FP16_TILE_SIZE ) ; // dequant buf 0
scratch1_size = scratch0_size ; // dequant buf 1
scratch2_size = output_area_size ; // output buf 1
} else {
scratch0_size = hex_align_up ( n_chunk_n_cols * row_stride , HMX_FP16_TILE_SIZE ) ; // x4x2 DMA buf 0
scratch1_size = scratch0_size ; // x4x2 DMA buf 1
scratch2_size = 0 ; // unused
}
scratch0_size = hex_align_up ( n_chunk_n_cols * vec_dot_size , HMX_FP16_TILE_SIZE ) ; // dequant buf 0
scratch1_size = scratch0_size ; // dequant buf 1
scratch2_size = output_area_size ; // output buf 1
uint8_t * vtcm_ptr = ( uint8_t * ) ctx - > vtcm_base ;
__fp16 * vtcm_weight = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , weight_area_size ) ;
__fp16 * vtcm_activation = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , act ivation _area_size) ;
__fp16 * vtcm_activation = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , act_area_size ) ;
__fp16 * vtcm_output = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , output_area_size ) ;
void * vtcm_scratch0 = vtcm_seq_alloc ( & vtcm_ptr , scratch0_size ) ;
void * vtcm_scratch1 = vtcm_seq_alloc ( & vtcm_ptr , scratch1_size ) ;
void * vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc ( & vtcm_ptr , scratch2_size ) : NULL ;
__fp16 * vtcm_scales = ( __fp16 * ) vtcm_seq_alloc ( & vtcm_ptr , 256 ) ;
if ( ( size_t ) ( vtcm_ptr - ( uint8_t * ) ctx - > vtcm_base ) > vtcm_budget ) {
FARF ( ERROR , " %s: vtcm overflow: used=%zu limit=%zu " , __func__ ,
( size_t ) ( vtcm_ptr - ( uint8_t * ) ctx - > vtcm_base ) , vtcm_budget ) ;
vtcm_used = vtcm_ptr - ( uint8_t * ) ctx - > vtcm_base ;
if ( vtcm_used > vtcm_budget ) {
FARF ( ERROR , " hmx-mm-q: VTCM overflow: used %zu budget %zu " , vtcm_used , vtcm_budget ) ;
return - 1 ;
}
hmx_init_column_scales ( vtcm_scales , Q6_V_vsplat_R ( 0x3c00 ) ) ; // scale: 1.0, bias: 0.0 in FP16
FARF ( HIGH , " %s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu " ,
__func__ , m , k , n , weight_type , use_pipeline ,
m_chunk_n_rows , n_chunk_n_cols ,
( size_t ) ( vtcm_ptr - ( uint8_t * ) ctx - > vtcm_base ) , vtcm_budget ) ;
FARF ( HIGH , " hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu " ,
m , k , n , weight_type , m_chunk_n_rows , n_chunk_n_cols , vtcm_used , vtcm_budget ) ;
TIMER_DEFINE ( activation_load ) ;
TIMER_DEFINE ( weight_load ) ;
@ -1178,184 +938,115 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
TIMER_DEFINE ( total ) ;
TIMER_START ( total ) ;
FARF ( HIGH , " hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu " ,
use_pipeline ? " PIPELINE " : " SEQUENTIAL " , m_chunk_n_rows , n_chunk_n_cols ,
( size_t ) ( vtcm_ptr - ( uint8_t * ) ctx - > vtcm_base ) , vtcm_budget ) ;
// 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
// HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D).
if ( ! use_pipeline ) {
HAP_compute_res_hmx_lock ( ctx - > vtcm_rctx ) ;
for ( size_t mr = 0 ; mr < m ; mr + = m_chunk_n_rows ) {
// transfer activation matrix chunk into VTCM
const size_t n_rows = hex_smin ( m - mr , m_chunk_n_rows ) ;
const size_t n_row_tiles = hmx_ceil_div ( n_rows , HMX_FP16_TILE_N_ROWS ) ;
// A --> B: vtcm_qweight, 1 buffer
// B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
// C --> D: vtcm_output0/vtcm_output1, 2 buffers
TIMER_START ( activation_load ) ;
{
const float * activation_chunk = activation + mr * k ;
transfer_activation_chunk_threaded ( ctx , vtcm_activation , activation_chunk , n_rows , k , k ) ;
}
TIMER_STOP ( activation_load ) ;
// Async timeline (C overlaps B+D):
// main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2]
// HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████]
void * buf_curr = vtcm_scratch0 ;
void * buf_next = vtcm_scratch1 ;
int n_chunk_cnt = hmx_ceil_div ( n , n_chunk_n_cols ) ;
hmx_matmul_job_t job_slots [ 2 ] ; // persistent double-buffered job descriptors
{
const size_t n_cols_first = hex_smin ( n , n_chunk_n_cols ) ;
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( buf_curr , permuted_weight ) , row_stride , row_stride , row_stride , n_cols_first ) ;
}
for ( size_t mr = 0 ; mr < m ; mr + = m_chunk_n_rows ) {
const size_t n_rows = hex_smin ( m - mr , m_chunk_n_rows ) ;
for ( size_t nc = 0 ; nc < n ; nc + = n_chunk_n_cols ) {
const size_t n_cols = hex_smin ( n - nc , n_chunk_n_cols ) ;
const size_t n_col_tiles = hmx_ceil_div ( n_cols , HMX_FP16_TILE_N_COLS ) ;
void * vtcm_qweight = vtcm_weight ;
void * vtcm_weight_bufs [ 2 ] = { vtcm_scratch0 , vtcm_scratch1 } ;
void * vtcm_output_bufs [ 2 ] = { vtcm_output , vtcm_scratch2 } ;
TIMER_START ( weight_load ) ;
{
dma_queue_pop ( ctx - > dma [ 0 ] ) ; // wait until the current weight chunk becomes ready
const size_t nc_next = nc + n_chunk_n_cols ;
if ( nc_next < n ) {
const size_t n_cols_next = hex_smin ( n - nc_next , n_chunk_n_cols ) ;
const uint8_t * next_weight_chunk = permuted_weight + nc_next * row_stride ;
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( buf_next , next_weight_chunk ) , row_stride , row_stride , row_stride , n_cols_next ) ;
}
// Dequant + vscatter writes directly to [K, N] transposed tiles.
// HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight.
dequantize_x4x2_weight_chunk_to_fp16_tiles ( ctx , vtcm_weight , buf_curr , n_cols , k , row_stride , weight_type ) ;
hex_swap_ptr ( & buf_curr , & buf_next ) ;
}
TIMER_STOP ( weight_load ) ;
TIMER_START ( hmx_core ) ;
{
core_dot_chunk_fp16 ( vtcm_output , vtcm_activation , vtcm_weight , vtcm_scales , n_row_tiles , n_col_tiles , k / 32 ) ;
}
TIMER_STOP ( hmx_core ) ;
TIMER_START ( output_store ) ;
{
float * output = dst + ( mr * n + nc ) ;
transfer_output_chunk_threaded ( ctx , output , vtcm_output , n_rows , n_cols , n ) ;
}
TIMER_STOP ( output_store ) ;
}
// prologue: A0
const size_t n_cols_A0 = hex_smin ( n - 0 * n_chunk_n_cols , n_chunk_n_cols ) ;
{
const uint8_t * qweight_chunk_A0 = permuted_weight ;
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( vtcm_qweight , qweight_chunk_A0 ) , row_stride , row_stride , row_stride , n_cols_A0 ) ;
}
HAP_compute_res_hmx_unlock ( ctx - > vtcm_rctx ) ;
} else {
// 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
// HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D).
// A --> B: vtcm_qweight, 1 buffer
// B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
// C --> D: vtcm_output0/vtcm_output1, 2 buffers
{
const float * activation_chunk = activation + mr * k ;
transfer_activation_chunk_threaded ( ctx , vtcm_activation , activation_chunk , n_rows , k , k ) ;
}
// Async timeline (C overlaps B+D):
// main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2]
// HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████]
// prologue: B0, A1, submit C0 (async), B1 (overlaps C0)
{
// B0: wait for DMA, dequant weight chunk 0
dma_queue_pop ( ctx - > dma [ 0 ] ) ;
dequantize_x4x2_weight_chunk_to_fp16_tiles ( ctx , vtcm_weight_bufs [ 0 ] , vtcm_qweight , n_cols_A0 , k , row_stride , weight_type ) ;
int n_chunk_cnt = hmx_ceil_div ( n , n_chunk_n_cols ) ;
hmx_matmul_job_t job_slots [ 2 ] ; // persistent double-buffered job descriptors
for ( size_t mr = 0 ; mr < m ; mr + = m_chunk_n_rows ) {
const size_t n_rows = hex_smin ( m - mr , m_chunk_n_rows ) ;
void * vtcm_qweight = vtcm_weight ;
void * vtcm_weight_bufs [ 2 ] = { vtcm_scratch0 , vtcm_scratch1 } ;
void * vtcm_output_bufs [ 2 ] = { vtcm_output , vtcm_scratch2 } ;
// prologue: A0
const size_t n_cols_A0 = hex_smin ( n - 0 * n_chunk_n_cols , n_chunk_n_cols ) ;
{
// Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow.
const uint8_t * qweight_chunk_A0 = permuted_weight ;
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( vtcm_qweight , qweight_chunk_A0 ) , row_stride , row_stride , row_stride , n_cols_A0 ) ;
// A1: issue DMA for weight chunk 1
const size_t n_cols_A1 = hex_smin ( n - 1 * n_chunk_n_cols , n_chunk_n_cols ) ;
if ( 1 < n_chunk_cnt ) {
const uint8_t * qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride ;
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( vtcm_qweight , qweight_chunk_A1 ) , row_stride , row_stride , row_stride , n_cols_A1 ) ;
}
{
const float * activation_chunk = activation + mr * k ;
transfer_activation_chunk_threaded ( ctx , vtcm_activation , activation_chunk , n_rows , k , k ) ;
}
// submit C0 (non-blocking — HMX worker executes in parallel)
hmx_matmul_job_init ( & job_slots [ 0 ] , ( __fp16 * ) vtcm_output_bufs [ 0 ] , ( __fp16 * ) vtcm_activation ,
( __fp16 * ) vtcm_weight_bufs [ 0 ] , vtcm_scales ,
hmx_ceil_div ( n_rows , HMX_FP16_TILE_N_ROWS ) ,
hmx_ceil_div ( n_cols_A0 , HMX_FP16_TILE_N_COLS ) , k / HMX_FP16_TILE_N_ROWS ) ;
hmx_queue_push ( ctx - > hmx_queue , hmx_queue_make_desc ( hmx_matmul_worker_fn , & job_slots [ 0 ] ) ) ;
// prologue: B0, A1, submit C0 (async), B1 (overlaps C0)
{
// B0: wait for DMA, dequant weight chunk 0
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
if ( 1 < n_chunk_cnt ) {
dma_queue_pop ( ctx - > dma [ 0 ] ) ;
dequantize_x4x2_weight_chunk_to_fp16_tiles ( ctx , vtcm_weight_bufs [ 0 ] , vtcm_qweight , n_cols_A0 , k , row_stride , weight_type ) ;
// A1: issue DMA for weight chunk 1
const size_t n_cols_A1 = hex_smin ( n - 1 * n_chunk_n_cols , n_chunk_n_cols ) ;
if ( 1 < n_chunk_cnt ) {
const uint8_t * qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride ;
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( vtcm_qweight , qweight_chunk_A1 ) , row_stride , row_stride , row_stride , n_cols_A1 ) ;
}
// submit C0 (non-blocking — HMX worker executes in parallel)
hmx_matmul_job_init ( & job_slots [ 0 ] , ( __fp16 * ) vtcm_output_bufs [ 0 ] , ( __fp16 * ) vtcm_activation ,
( __fp16 * ) vtcm_weight_bufs [ 0 ] , vtcm_scales ,
hmx_ceil_div ( n_rows , HMX_FP16_TILE_N_ROWS ) ,
hmx_ceil_div ( n_cols_A0 , HMX_FP16_TILE_N_COLS ) , k / HMX_FP16_TILE_N_ROWS ) ;
hmx_queue_push ( ctx - > hmx_queue , hmx_queue_make_desc ( hmx_matmul_worker_fn , & job_slots [ 0 ] ) ) ;
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
if ( 1 < n_chunk_cnt ) {
dma_queue_pop ( ctx - > dma [ 0 ] ) ;
dequantize_x4x2_weight_chunk_to_fp16_tiles ( ctx , vtcm_weight_bufs [ 1 ] , vtcm_qweight , n_cols_A1 , k , row_stride , weight_type ) ;
}
}
// main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1})
for ( int i = 0 ; i < n_chunk_cnt ; + + i ) {
const size_t nc = i * n_chunk_n_cols ;
const size_t nc_p1 = nc + 1 * n_chunk_n_cols ;
const size_t nc_p2 = nc + 2 * n_chunk_n_cols ;
const size_t n_cols = hex_smin ( n - nc , n_chunk_n_cols ) ;
const size_t n_cols_p1 = hex_smin ( n - nc_p1 , n_chunk_n_cols ) ;
const size_t n_cols_p2 = hex_smin ( n - nc_p2 , n_chunk_n_cols ) ;
// issue A_{i+2}: DMA push (non-blocking)
if ( i + 2 < n_chunk_cnt ) {
const uint8_t * qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride ;
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( vtcm_qweight , qweight_chunk_p2 ) , row_stride , row_stride , row_stride , n_cols_p2 ) ;
}
// wait C_i: block until prologue/previous C completes
hmx_queue_pop ( ctx - > hmx_queue ) ;
// submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below)
// job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's
// counterpart — and (i+1)%2 was last used by C_{i-1} which completed
// before C_i was submitted.
if ( i + 1 < n_chunk_cnt ) {
hmx_matmul_job_init ( & job_slots [ ( i + 1 ) % 2 ] , ( __fp16 * ) vtcm_output_bufs [ ( i + 1 ) % 2 ] ,
( __fp16 * ) vtcm_activation , ( __fp16 * ) vtcm_weight_bufs [ ( i + 1 ) % 2 ] ,
vtcm_scales , hmx_ceil_div ( n_rows , HMX_FP16_TILE_N_ROWS ) ,
hmx_ceil_div ( n_cols_p1 , HMX_FP16_TILE_N_COLS ) , k / HMX_FP16_TILE_N_ROWS ) ;
hmx_queue_push ( ctx - > hmx_queue , hmx_queue_make_desc ( hmx_matmul_worker_fn , & job_slots [ ( i + 1 ) % 2 ] ) ) ;
}
// D_i: store output (multi-thread HVX, parallel with C_{i+1})
float * output_chunk = dst + ( mr * n + nc ) ;
transfer_output_chunk_threaded ( ctx , output_chunk , vtcm_output_bufs [ i % 2 ] , n_rows , n_cols , n ) ;
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
if ( i + 2 < n_chunk_cnt ) {
dma_queue_pop ( ctx - > dma [ 0 ] ) ;
dequantize_x4x2_weight_chunk_to_fp16_tiles ( ctx , vtcm_weight_bufs [ ( i + 2 ) % 2 ] , vtcm_qweight , n_cols_p2 , k , row_stride , weight_type ) ;
}
dequantize_x4x2_weight_chunk_to_fp16_tiles ( ctx , vtcm_weight_bufs [ 1 ] , vtcm_qweight , n_cols_A1 , k , row_stride , weight_type ) ;
}
}
hmx_queue_suspend ( ctx - > hmx_queue ) ;
// main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1})
for ( int i = 0 ; i < n_chunk_cnt ; + + i ) {
const size_t nc = i * n_chunk_n_cols ;
const size_t nc_p1 = nc + 1 * n_chunk_n_cols ;
const size_t nc_p2 = nc + 2 * n_chunk_n_cols ;
const size_t n_cols = hex_smin ( n - nc , n_chunk_n_cols ) ;
const size_t n_cols_p1 = hex_smin ( n - nc_p1 , n_chunk_n_cols ) ;
const size_t n_cols_p2 = hex_smin ( n - nc_p2 , n_chunk_n_cols ) ;
// issue A_{i+2}: DMA push (non-blocking)
if ( i + 2 < n_chunk_cnt ) {
const uint8_t * qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride ;
dma_queue_push ( ctx - > dma [ 0 ] , dma_make_ptr ( vtcm_qweight , qweight_chunk_p2 ) , row_stride , row_stride , row_stride , n_cols_p2 ) ;
}
// wait C_i: block until prologue/previous C completes
hmx_queue_pop ( ctx - > hmx_queue ) ;
// submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below)
// job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's
// counterpart — and (i+1)%2 was last used by C_{i-1} which completed
// before C_i was submitted.
if ( i + 1 < n_chunk_cnt ) {
hmx_matmul_job_init ( & job_slots [ ( i + 1 ) % 2 ] , ( __fp16 * ) vtcm_output_bufs [ ( i + 1 ) % 2 ] ,
( __fp16 * ) vtcm_activation , ( __fp16 * ) vtcm_weight_bufs [ ( i + 1 ) % 2 ] ,
vtcm_scales , hmx_ceil_div ( n_rows , HMX_FP16_TILE_N_ROWS ) ,
hmx_ceil_div ( n_cols_p1 , HMX_FP16_TILE_N_COLS ) , k / HMX_FP16_TILE_N_ROWS ) ;
hmx_queue_push ( ctx - > hmx_queue , hmx_queue_make_desc ( hmx_matmul_worker_fn , & job_slots [ ( i + 1 ) % 2 ] ) ) ;
}
// D_i: store output (multi-thread HVX, parallel with C_{i+1})
float * output_chunk = dst + ( mr * n + nc ) ;
transfer_output_chunk_threaded ( ctx , output_chunk , vtcm_output_bufs [ i % 2 ] , n_rows , n_cols , n ) ;
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
if ( i + 2 < n_chunk_cnt ) {
dma_queue_pop ( ctx - > dma [ 0 ] ) ;
dequantize_x4x2_weight_chunk_to_fp16_tiles ( ctx , vtcm_weight_bufs [ ( i + 2 ) % 2 ] , vtcm_qweight , n_cols_p2 , k , row_stride , weight_type ) ;
}
}
}
hmx_queue_suspend ( ctx - > hmx_queue ) ;
TIMER_STOP ( total ) ;
# if defined(ENABLE_PROFILE_TIMERS)
FARF ( HIGH , " %s: %lld us, m=%d k=%d n=%d pipeline=%d " , __func__ , TIMER_US ( total ) , m , k , n , use_pipeline ) ;
FARF ( HIGH , " hex-mm-q: %lld us : m %d k %d n %d " , TIMER_US ( total ) , m , k , n ) ;
if ( ! use_pipeline ) {
FARF ( HIGH , " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us " ,
TIMER_US ( activation_load ) , TIMER_US ( weight_load ) , TIMER_US ( hmx_core ) , TIMER_US ( output_store ) ) ;
@ -1370,15 +1061,15 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
//
static inline int hmx_matmul_batch_r2 ( const hmx_matmul_ w16a 32_batched_params_t * params ) {
static inline int hmx_matmul_batch_r2 ( const hmx_matmul_ f16_f 32_batched_params_t * params ) {
return params - > ne02 > 0 ? params - > ne12 / params - > ne02 : 1 ;
}
static inline int hmx_matmul_batch_r3 ( const hmx_matmul_ w16a 32_batched_params_t * params ) {
static inline int hmx_matmul_batch_r3 ( const hmx_matmul_ f16_f 32_batched_params_t * params ) {
return params - > ne03 > 0 ? params - > ne13 / params - > ne03 : 1 ;
}
static inline const __fp16 * hmx_matmul_weight_batch_ptr ( const hmx_matmul_ w16a 32_batched_params_t * params ,
static inline const __fp16 * hmx_matmul_weight_batch_ptr ( const hmx_matmul_ f16_f 32_batched_params_t * params ,
int dst_b2 , int dst_b3 ) {
const int r2 = hmx_matmul_batch_r2 ( params ) ;
const int r3 = hmx_matmul_batch_r3 ( params ) ;
@ -1387,37 +1078,36 @@ static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_
( size_t ) ( dst_b3 / r3 ) * params - > src0_nb3 ) ;
}
static inline const float * hmx_matmul_activation_batch_ptr ( const hmx_matmul_ w16a 32_batched_params_t * params ,
static inline const float * hmx_matmul_activation_batch_ptr ( const hmx_matmul_ f16_f 32_batched_params_t * params ,
int dst_b2 , int dst_b3 ) {
return ( const float * ) ( ( const uint8_t * ) params - > activation +
( size_t ) dst_b2 * params - > src1_nb2 +
( size_t ) dst_b3 * params - > src1_nb3 ) ;
}
static inline float * hmx_matmul_dst_batch_ptr ( const hmx_matmul_ w16a 32_batched_params_t * params ,
static inline float * hmx_matmul_dst_batch_ptr ( const hmx_matmul_ f16_f 32_batched_params_t * params ,
int dst_b2 , int dst_b3 ) {
return ( float * ) ( ( uint8_t * ) params - > dst +
( size_t ) dst_b2 * params - > dst_nb2 +
( size_t ) dst_b3 * params - > dst_nb3 ) ;
}
static int hmx_mat _mul_permuted_w16a 32_batched_legacy( struct htp_context * ctx ,
const hmx_matmul_ w16a 32_batched_params_t * params ) {
static int hmx_mat mul_f16_f 32_batched_legacy( struct htp_context * ctx ,
const hmx_matmul_ f16_f 32_batched_params_t * params ) {
int ret = 0 ;
for ( int b3 = 0 ; b3 < params - > ne13 & & ret = = 0 ; + + b3 ) {
for ( int b2 = 0 ; b2 < params - > ne12 & & ret = = 0 ; + + b2 ) {
ret = hmx_mat_mul_permuted_w16a32 ( ctx ,
hmx_matmul_dst_batch_ptr ( params , b2 , b3 ) ,
hmx_matmul_activation_batch_ptr ( params , b2 , b3 ) ,
hmx_matmul_weight_batch_ptr ( params , b2 , b3 ) ,
params - > m , params - > k , params - > n ,
params - > act_stride , params - > weight_stride ) ;
ret = hmx_matmul_f16_f32 ( ctx , hmx_matmul_dst_batch_ptr ( params , b2 , b3 ) ,
hmx_matmul_activation_batch_ptr ( params , b2 , b3 ) ,
hmx_matmul_weight_batch_ptr ( params , b2 , b3 ) ,
params - > m , params - > k , params - > n ,
params - > act_stride , params - > weight_stride ) ;
}
}
return ret ;
}
int hmx_mat _mul_permuted_w16a 32_batched( struct htp_context * ctx , const hmx_matmul_ w16a 32_batched_params_t * params ) {
int hmx_mat mul_f16_f 32_batched( struct htp_context * ctx , const hmx_matmul_ f16_f 32_batched_params_t * params ) {
if ( ! ctx | | ! params | | ! params - > dst | | ! params - > activation | | ! params - > permuted_weight ) { return - 1 ; }
if ( ! params - > m | | ! params - > k | | ! params - > n ) { return - 1 ; }
if ( params - > act_stride < params - > k | | params - > weight_stride < params - > k | | params - > dst_stride < params - > n ) { return - 1 ; }
@ -1435,7 +1125,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
if ( group_size < = 1 ) {
FARF ( HIGH , " %s: no dim2 GQA reuse (group=%d), using legacy batched loop " , __func__ , group_size ) ;
return hmx_mat _mul_permuted_w16a 32_batched_legacy( ctx , params ) ;
return hmx_mat mul_f16_f 32_batched_legacy( ctx , params ) ;
}
// Grouped path: reuse interleaved weight across all q_heads sharing a
@ -1464,7 +1154,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
/*m_block_cost=*/ ( size_t ) params - > n ,
/*n_block_cost=*/ ( size_t ) params - > m , & m_chunk_n_rows , & n_chunk_n_cols , & vtcm_used ) ! = 0 ) {
FARF ( HIGH , " %s: grouped path does not fit VTCM, falling back to legacy batched loop " , __func__ ) ;
return hmx_mat _mul_permuted_w16a 32_batched_legacy( ctx , params ) ;
return hmx_mat mul_f16_f 32_batched_legacy( ctx , params ) ;
}
const size_t act_head_stride = m_chunk_n_rows * ( size_t ) params - > k ; // fp16 elements between heads
@ -1486,7 +1176,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
if ( ( size_t ) ( vtcm_ptr - ( uint8_t * ) ctx - > vtcm_base ) > vtcm_budget ) {
FARF ( HIGH , " %s: grouped layout overflowed VTCM, falling back to legacy batched loop " , __func__ ) ;
return hmx_mat _mul_permuted_w16a 32_batched_legacy( ctx , params ) ;
return hmx_mat mul_f16_f 32_batched_legacy( ctx , params ) ;
}
hmx_init_column_scales ( vtcm_scales , Q6_V_vsplat_R ( 0x3c00 ) ) ; // scale: 1.0, bias: 0.0 in FP16
@ -1614,7 +1304,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
//
int hmx_mat _mul_permuted_w16a 32( struct htp_context * ctx , float * restrict dst , const float * restrict activation ,
int hmx_mat mul_f16_f 32( struct htp_context * ctx , float * restrict dst , const float * restrict activation ,
const __fp16 * restrict permuted_weight , int m , int k , int n ,
int act_stride , int weight_stride ) {
if ( ! dst | | ! activation | | ! permuted_weight | | ! m | | ! n | | ! k ) { return - 1 ; }