From aa50b2c2ae91326d5aad956ceeb015d1d48e626b Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 27 May 2026 10:46:11 -0700 Subject: [PATCH] hexagon: add support for Q4_1 in MUL_MAT and MUL_MAT_ID (#23647) * hex-mm: add support for Q4_1 matmul/matvec, hvx-only for now * hmx-mm: add support for Q4_1 * hex-mm: use Q8_1 dynamic quantization to avoid having to compute sums in the vec_dot * hexagon: fix repack scratch buffer overflow * hex-mm: fix Q4_1 repack buffer sizing * hexagon: flip the build order for mm and fa (seems to help LTO) * hex-mm: add vec_dot 4x1s and minor HMX cleanup after adding Q4_1 * hex-mm: fix fp16 vec_dot fallback to 2x1 and another issue that could cause incorrect output * hexagon: resurrect early-wake and add support for polling for op-batch completions With Q4_1 ggml-hexagon now claims pretty much the entire graphs which gives the CPU more time to chilax. This is a good thing! But it does add extra latency for the pure benchmark runs. Early wakeup helps recover the latency a bit in the normals runs and op-batch polling is just for benchmarking. --------- Co-authored-by: Todor Boinovski --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 267 +++- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 4 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 151 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 7 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 1517 +++++++++++++++++++- scripts/snapdragon/adb/run-completion.sh | 5 +- scripts/snapdragon/adb/run-tool.sh | 11 +- 8 files changed, 1890 insertions(+), 74 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 1c8ecc197..5e8a4a740 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -68,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches +static int opt_oppoll = 0; // polling for batch completions static std::regex* opt_opfilter = NULL; // regex of ops to not claim @@ -550,7 +551,7 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -611,7 +612,7 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -660,6 +661,239 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) ggml_aligned_free(buf_rp, row_size_rp); } +static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); + qs[bi * qk + i + 0] = x0; + qs[bi * qk + i + qk / 2] = x1; + } +} + +static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = qs[bi * qk + i + 0]; + const uint8_t x1 = qs[bi * qk + i + qk / 2]; + x->qs[i] = x0 | (x1 << 4); + } +} + +static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); + unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); + unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); + unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); + unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); + unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); + unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); + unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); + + bool partial = (nloe && i == nb-1); + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; + } + } + + // Repack the scales and offsets + for (int i = 0; i < nb; i++) { + ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + d_m[j * 2 + 0] = x[i * 8 + j].d; + d_m[j * 2 + 1] = x[i * 8 + j].m; + } + } +} + +static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + const uint8_t * y_q = y + 0; // quants first + const uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; + bool partial = (nloe && i == nb-1); + + const uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + if (partial) { + qs[j*2+0] = q[j] & 0x0F; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0x0F; + qs[j+128] = q[j] >> 4; + } + } + + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + // Unpack the scales and offsets + for (int i = 0; i < nb; i++) { + const ggml_half * d_m = (const ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = d_m[j * 2 + 0]; + x[i * 8 + j].m = d_m[j * 2 + 1]; + } + } +} + +static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + memset(qs, 0, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = 0; + x[i * 8 + j].m = 0; + } + } +} + +static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + memcpy(buf_pd, src, n_rem_bytes); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + memset(buf_rp, 0, row_size_rp); // clear-out padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + // We still need to read and unpack the entire source row because quantization is block-based. + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + // ======== Q8x4x2 ==================== static void dump_block_q8_0(const block_q8_0 * b, int i) { HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2], @@ -876,7 +1110,7 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -937,7 +1171,7 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1238,7 +1472,7 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -1299,7 +1533,7 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1365,6 +1599,12 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q4_0_q4x4x2(tensor, data, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4_1_q4x4x2(tensor, data, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1407,6 +1647,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q4x4x2_q4_0(data, tensor, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_1(data, tensor, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1886,7 +2132,8 @@ void ggml_hexagon_session::flush_pending(bool all) { uint32_t n_dbufs; // Read response packet from queue - int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, DSPQUEUE_TIMEOUT); + const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT; + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo); if (err == AEE_EEXPIRED) { continue; } @@ -2327,6 +2574,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -2377,6 +2625,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -3622,6 +3871,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1, + "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0, "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, @@ -3634,6 +3885,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL"); const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); @@ -3671,6 +3923,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; opt_profile = str_profile ? atoi(str_profile) : 0; opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 33d67dda9..d7927261a 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -59,14 +59,14 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE hmx-queue.c - hmx-matmul-ops.c hmx-flash-attn-ops.c + hmx-matmul-ops.c ) # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) set_source_files_properties( - hmx-matmul-ops.c hmx-flash-attn-ops.c + hmx-matmul-ops.c PROPERTIES COMPILE_OPTIONS "-mhmx" ) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 3ef0bcdb2..ab5fd7338 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -34,6 +34,10 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; +static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, +}; + // MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value // kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { @@ -62,6 +66,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { case HTP_TYPE_Q4_0: case HTP_TYPE_IQ4_NL: return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q4_1: + return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb case HTP_TYPE_Q8_0: return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: @@ -233,6 +239,54 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( return r; } +static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_dm = hvx_vmemu(scale_offset); + HVX_Vector v_scales = hvx_vec_repl_f16(v_dm); + HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2)); + + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); +} + +static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); + HVX_Vector vd = Q6_V_lo_W(dm_deal); + HVX_Vector vm = Q6_V_hi_W(dm_deal); + + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4)); + + HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm); + HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { HVX_Vector vq = hvx_vmemu(quants_32); @@ -331,11 +385,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int start_tile, int end_tile) { const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL); + const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1); const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : + (weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut); // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. @@ -356,8 +412,10 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 bool upper = (sub_blk_base >= 4); unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes - unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE - + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales + unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; + unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); + unsigned scale_off = qrow_size + blk_idx * dblk_size + + sub_blk_base * scale_step; __fp16 *tile_bases[4]; for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } @@ -367,20 +425,38 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + if (is_q4_1) { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + } else { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } } for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } @@ -446,26 +522,43 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; bool upper = (sub_blk >= 4); unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; + unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); + unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step; HVX_Vector v_off = v_scat_base; // reset to column 0 unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + if (is_q4_1) { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( - r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q4_0_group_hvx( - r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) - : Q6_V_vzero(); + HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + } else { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } } (void) *(volatile HVX_Vector *)(tile_base); } else if (weight_type == HTP_TYPE_MXFP4) { @@ -593,6 +686,8 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( // --- End x4x2 dequantizers --- +#pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics + // requires external HMX lock static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, int n_row_tiles, int n_col_tiles, int n_dot_tiles) { diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 54cfadd9b..aadc77235 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -20,6 +20,7 @@ enum htp_data_type { HTP_TYPE_F32 = 0, HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q4_1 = 3, HTP_TYPE_Q8_0 = 8, HTP_TYPE_IQ4_NL = 20, HTP_TYPE_I32 = 26, @@ -28,6 +29,7 @@ enum htp_data_type { // types used internally for repack, dyn.quant, etc HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q4_1x4x2, HTP_TYPE_Q8_0x4x2, HTP_TYPE_MXFP4x4x2, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index f3a0866c7..7dd90ac7d 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -853,6 +853,11 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; + if (i == (n_ops-1)) { + // wake up the host before starting the last op + dspqueue_write_early_wakeup_noblock(queue, 0, 0); + } + profile_start(ctx->profiler, &prof); proc_op_req(octx, tens, i, &ops[i]); @@ -869,8 +874,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { } } - // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); - struct htp_opbatch_rsp rsp; rsp.id = req.id; rsp.status = HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 46fc5602d..7036c491b 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -40,6 +40,11 @@ struct htp_matmul_context { const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1); + void (*vec_dot_4x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0); + // Precomputed values uint32_t src0_nrows_per_thread; uint32_t src1_nrows_per_thread; @@ -155,6 +160,13 @@ static inline size_t q8x4x2_row_size(uint32_t ne) { return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); } +static inline size_t q8_1x4x2_row_size(uint32_t ne) { + // ensures perfect alignment of quants and full row + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (ne + qk - 1) / qk; + return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128); +} + static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; @@ -223,6 +235,62 @@ static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, u return r; } +static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = v0; + r.v[i*2+1] = v1; + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_V_lo_W(v0_1_p); + r.v[i*2+1] = Q6_V_hi_W(v0_1_p); + } + + return r; +} + static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; @@ -401,6 +469,607 @@ static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 return hvx_vec_rmpy_x8_partial(x, y, 512); } +static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elemements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elemements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); + + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); + + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); + + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r2_ms = Q6_V_vand_QV(bmask, r2_ms); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r3_ms = Q6_V_vand_QV(bmask, r3_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); + + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales/sums + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + + // Load src0 rows + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); + + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); + + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms); + + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); @@ -570,6 +1239,133 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, hvx_vec_store_u(s0, 8, rsum); } +static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { @@ -869,6 +1665,133 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, hvx_vec_store_u(s0, 8, rsum); } +static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { @@ -1163,6 +2086,135 @@ static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, hvx_vec_store_u(s0, 8, rsum); } +static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vx2, + const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, @@ -1521,6 +2573,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, hvx_vec_store_u(s0, 8, rsum); } +static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); + + // Zero-out unused values + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { @@ -2138,7 +3360,6 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { @@ -2168,39 +3389,89 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_col = (float *) dst->data; - // Prefill spad with 2x src0 rows - #pragma unroll(2) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; + if (mmctx->vec_dot_4x1 != NULL) { + const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U); + + // Prefill spad with 4x src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 4); } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - } - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col); - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x4) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 4); + } + } + + // Process leftovers + uint32_t ir0 = src0_end_row_x4; + if (ir0 + 2 <= src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + ir0 += 2; + } + if (ir0 < src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + ir0 += 1; + } + } else { + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + // Prefill spad with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 2); } - } - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + } } hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); @@ -2432,6 +3703,94 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) { // *** dynamic quant +static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + // Use reduce max fp32 to find max(abs(e)) first + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + // Load and convert into QF32 + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert to QF32 + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + // Combine and convert to fp16 + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + // Divide input by the scale + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; + + // --- Sum calculation --- + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements + // Sum 8 elements: + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + // Copy to stack to extract sums and vmaxes + float vmax0[32] __attribute__((aligned(128))); + float vmax1[32] __attribute__((aligned(128))); + float vmax2[32] __attribute__((aligned(128))); + float vmax3[32] __attribute__((aligned(128))); + int32_t sums[32] __attribute__((aligned(128))); + + hvx_vec_store_u(vmax0, 128, vmax0_sf); + hvx_vec_store_u(vmax1, 128, vmax1_sf); + hvx_vec_store_u(vmax2, 128, vmax2_sf); + hvx_vec_store_u(vmax3, 128, vmax3_sf); + hvx_vec_store_u(sums, 128, v_sums); + + float d0 = vmax0[0] / 127.0f; + float d1 = vmax1[0] / 127.0f; + float d2 = vmax2[0] / 127.0f; + float d3 = vmax3[0] / 127.0f; + + __fp16 * y_d_half = (__fp16 *) y_d; + y_d_half[0] = d0; + y_d_half[1] = (float) sums[0] * d0; + y_d_half[2] = d1; + y_d_half[3] = (float) sums[8] * d1; + y_d_half[4] = d2; + y_d_half[5] = (float) sums[16] * d2; + y_d_half[6] = d3; + y_d_half[7] = (float) sums[24] * d3; +} + static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); @@ -2656,6 +4015,77 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (k + qk - 1) / qk; + + const uint32_t qrow_size = k; // int8 + + const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes + const uint32_t qblk_size = QK_Q8_0x4x2; // int8 + + uint8_t * restrict y_q = (y + 0); // quants first + uint8_t * restrict y_d = (y + qrow_size); // then scales/sums + + // Temp scales override input since we're working off of the aligned temp buffer in VTCM + uint8_t * restrict t_d = (uint8_t *) x; + + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); + } + + // now copy the scales/sums into final location + hvx_copy_f16_ua(y_d, t_d, nb * 16); +} + +static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = src->nb[1]; + const size_t dst_row_size = q8_1x4x2_row_size(ne0); + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); + uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); + + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); + memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; @@ -2751,24 +4181,35 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1; + return 0; + case HTP_TYPE_Q4_1: + mmctx->type = "q4_1x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1; return 0; case HTP_TYPE_Q8_0: mmctx->type = "q8x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1; return 0; case HTP_TYPE_IQ4_NL: mmctx->type = "iq4nlx4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1; return 0; case HTP_TYPE_MXFP4: mmctx->type = "mxfp4x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1; return 0; default: return -1; @@ -2894,8 +4335,13 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } @@ -2962,7 +4408,7 @@ int op_matmul(struct htp_ops_context * octx) { // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. // Other types fall back to HVX. uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { return op_matmul_hvx(octx); } @@ -3098,8 +4544,13 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); diff --git a/scripts/snapdragon/adb/run-completion.sh b/scripts/snapdragon/adb/run-completion.sh index 0d9b10271..fe14bb142 100755 --- a/scripts/snapdragon/adb/run-completion.sh +++ b/scripts/snapdragon/adb/run-completion.sh @@ -51,6 +51,9 @@ opbatch= opqueue= [ "$OQ" != "" ] && opqueue="GGML_HEXAGON_OPQUEUE=$OQ" +oppoll= +[ "$OP" != "" ] && oppoll="GGML_HEXAGON_OPPOLL=$OP" + opflt= [ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF" @@ -66,7 +69,7 @@ adb $adbserial $adbhost shell " \ cd $basedir; ulimit -c unlimited; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $opflt $vmem $mbuf \ + $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $vmem $mbuf \ ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ --ctx-size 8192 --ubatch-size 1024 -fa on \ diff --git a/scripts/snapdragon/adb/run-tool.sh b/scripts/snapdragon/adb/run-tool.sh index 27cbb2b6d..6d7e32b32 100755 --- a/scripts/snapdragon/adb/run-tool.sh +++ b/scripts/snapdragon/adb/run-tool.sh @@ -42,6 +42,15 @@ ndev= hb= [ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB" +opbatch= +[ "$OB" != "" ] && opbatch="GGML_HEXAGON_OPBATCH=$OB" + +opqueue= +[ "$OQ" != "" ] && opqueue="GGML_HEXAGON_OPQUEUE=$OQ" + +oppoll= +[ "$OP" != "" ] && oppoll="GGML_HEXAGON_OPPOLL=$OP" + set -x tool=$1; shift @@ -50,5 +59,5 @@ adb $adbserial $adbhost shell " \ cd $basedir; ulimit -c unlimited; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb ./$branch/bin/$tool $@ \ + $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll ./$branch/bin/$tool $@ \ "