// Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu.h // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // ██████╗ ██╗ █████╗ ██████╗ // ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝ // ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗ // ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║ // ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║ // ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝ // // BASIC LINEAR ALGEBRA SUBPROGRAMS // // // This file implements multithreaded CPU matrix multiplication for the // common contiguous use case C = Aᵀ * B. These kernels are designed to // have excellent performance[1] for matrices that fit in the CPU cache // without imposing any overhead such as cache filling or malloc calls. // // This implementation does not guarantee any upper bound with rounding // errors, which grow along with k. Our goal's to maximally exploit the // hardware for performance, and then use whatever resources remain for // improving numerical accuracy. // // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. #pragma once #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" // #include "log.h" #include "flags.h" #include "sgemm.h" // #include #pragma GCC diagnostic ignored "-Wpedantic" #pragma GCC diagnostic ignored "-Wignored-attributes" #define ROW_ALIGN 64 #define MATRIX_ALIGN 4096 #define MAX_ALIGN 4096 #ifdef _MSC_VER #define NOINLINE __declspec(noinline) #else #define NOINLINE __attribute__((__noinline__)) #endif #if defined(__ARM_NEON) || defined(__AVX512F__) #define VECTOR_REGISTERS 32 #else #define VECTOR_REGISTERS 16 #endif #if 0 #define NOT_SUPPORTED tinyBLAS_not_supported(__FILE__, __LINE__) #else #define NOT_SUPPORTED false #endif #define WANT_QUANTIZATION false namespace { bool tinyBLAS_not_supported(const char* file, int line) { // tinylogf("%s:%d: tinyBLAS not supported\n", file, line); return false; } inline float unhalf(ggml_fp16_t d) { return GGML_FP16_TO_FP32(d); } inline float unhalf(ggml_bf16_t d) { return GGML_BF16_TO_FP32(d); } //////////////////////////////////////////////////////////////////////////////////////////////////// // MATRIX MEMORY INDEXING #define NCA 1 #define NCB 2 #define NCC 4 #define INDEX(A, lda, j, i) (CONFIG & NC##A ? ((T##A**)A)[j] + i : A + lda * (j) + i) //////////////////////////////////////////////////////////////////////////////////////////////////// // GGML TYPE TRAITS template struct ggml_type_trait; template <> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_F32; }; template <> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_BF16; }; template <> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_F16; }; template <> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_Q8_0; }; //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED ARITHMETIC OPERATIONS #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); } inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); } inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); } #endif // __SSE__ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); } inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); } inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); } #endif // __AVX__ #if defined(__AVX512F__) inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); } inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); } inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); } #endif // __AVX512F__ #if defined(__ARM_NEON) inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); } inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); } inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); } #endif // __ARM_NEON #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); } inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD /** * Computes a * b + c. */ template inline U madd(T a, T b, U c) { return add(mul(a, b), c); } /** * Computes a * b + c with error correction. * * @see W. Kahan, "Further remarks on reducing truncation errors," * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965, * doi: 10.1145/363707.363723. */ template inline U madder(T a, T b, U c, U* e) { U y = sub(mul(a, b), *e); U t = add(c, y); *e = sub(sub(t, c), y); return t; } #ifdef __ARM_NEON inline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t* e) { float32x4_t y = sub(vmulq_n_f32(a, b), *e); float32x4_t t = add(c, y); *e = sub(sub(t, c), y); return t; } #endif #if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> inline __m256 madd(__m256 a, __m256 b, __m256 c) { return _mm256_fmadd_ps(a, b, c); } #endif #if defined(__AVX512F__) template <> inline __m512 madd(__m512 a, __m512 b, __m512 c) { return _mm512_fmadd_ps(a, b, c); } #endif #endif #if defined(__ARM_FEATURE_FMA) template <> inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { return vfmaq_f32(c, a, b); } #if 0 // todo: this specialization chops gcc 12.3 performance in half #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) && 0 template <> inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { return vfmaq_f16(c, b, a); } #endif #endif #endif #if defined(__AVX512BF16__) template <> inline __m512 madd(__m512bh x, __m512bh y, __m512 z) { return _mm512_dpbf16_ps(z, x, y); } template <> inline __m512 madder(__m512bh x, __m512bh y, __m512 z, __m512* _) { return _mm512_dpbf16_ps(z, x, y); } #endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM #if defined(__ARM_NEON) inline float hsum(float32x4_t x) { return vaddvq_f32(x); } #endif // __ARM_NEON #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) inline float hsum(float16x8_t x) { // todo: this works great on clang but it produces terrible code on gcc 12.3 return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x)))); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline float hsum(__m128 x) { #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) x = _mm_add_ps(x, _mm_movehl_ps(x, x)); x = _mm_add_ss(x, _mm_movehdup_ps(x)); #else __m128 t; t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1)); x = _mm_add_ps(x, t); t = _mm_movehl_ps(t, x); x = _mm_add_ss(x, t); #endif return _mm_cvtss_f32(x); } #endif #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline float hsum(__m256 x) { return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x))); } #endif // __AVX__ #if defined(__AVX512F__) inline float hsum(__m512 x) { return _mm512_reduce_add_ps(x); } #endif // __AVX512F__ //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED MEMORY LOADING template T load(const U*); template <> inline float load(const float* p) { return *p; } template <> inline float load(const ggml_fp16_t* p) { return unhalf(*p); } template <> inline float load(const ggml_bf16_t* p) { return unhalf(*p); } #if defined(__ARM_NEON) template <> inline float32x4_t load(const float* p) { return vld1q_f32(p); } template <> inline float32x4_t load(const ggml_bf16_t* p) { return vreinterpretq_f32_u32(vshll_n_u16(vld1_u16((const unsigned short*)p), 16)); } #if !defined(_MSC_VER) template <> inline float16x8_t load(const ggml_fp16_t* p) { return vld1q_f16((const float16_t*)p); } template <> inline float32x4_t load(const ggml_fp16_t* p) { return vcvt_f32_f16(vld1_f16((const float16_t*)p)); } #endif // _MSC_VER #endif // __ARM_NEON #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> inline __m128 load(const float* p) { return _mm_loadu_ps(p); } #endif // __SSE__ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> inline __m256 load(const float* p) { return _mm256_loadu_ps(p); } #endif // __AVX__ #if defined(__AVX2__) || defined(__AVX512F__) template <> inline __m256 load(const ggml_bf16_t* p) { return _mm256_castsi256_ps( _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)p)), 16)); } #endif // __AVX2__ #if defined(__F16C__) template <> inline __m256 load(const ggml_fp16_t* p) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)p)); } #endif // __F16C__ #if defined(__AVX512F__) template <> inline __m512 load(const float* p) { return _mm512_loadu_ps(p); } template <> inline __m512 load(const ggml_fp16_t* p) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)p)); } template <> inline __m512 load(const ggml_bf16_t* p) { return _mm512_castsi512_ps( _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)p)), 16)); } #endif // __AVX512F__ #if defined(__AVX512BF16__) template <> inline __m512bh load(const ggml_bf16_t* p) { return (__m512bh)_mm512_loadu_ps((const float*)p); } template <> inline __m512bh load(const float* p) { return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p)); } #endif // __AVX512BF16__ //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT OUTPUT STREAMING inline void store(float* p, float f) { *p = f; } inline void store(ggml_fp16_t* p, float f) { *p = GGML_FP32_TO_FP16(f); } inline void store(ggml_bf16_t* p, float f) { *p = GGML_FP32_TO_BF16(f); } //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION template class tinyBLAS { public: tinyBLAS(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(long m, long n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: NOINLINE void mnpack(long m0, long m, long n0, long n) { long mc, nc, mp, np; #if VECTOR_REGISTERS == 32 if (!FLAG_precise) { switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { case 0x55: mc = 5; nc = 5; gemm<5, 5, false>(m0, m, n0, n); break; case 0x54: case 0x53: case 0x52: case 0x45: case 0x44: case 0x43: case 0x42: case 0x35: case 0x34: case 0x33: case 0x32: case 0x25: case 0x24: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x51: case 0x41: case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, false>(m0, m, n0, n); break; case 0x15: case 0x14: case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, false>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, false>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) { case 0x43: mc = 4; nc = 3; gemm<4, 3, true>(m0, m, n0, n); break; case 0x42: case 0x33: case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, true>(m0, m, n0, n); break; case 0x41: case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } #endif #if VECTOR_REGISTERS == 16 if (!FLAG_precise) { switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) { case 0x43: mc = 4; nc = 3; gemm<4, 3, false>(m0, m, n0, n); break; case 0x42: case 0x33: case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x41: case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, false>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, false>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, false>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) { case 0x32: mc = 3; nc = 2; gemm<3, 2, true>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; gemm<2, 3, true>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; gemm<2, 2, true>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } #endif mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } template NOINLINE void gemm(long m0, long m, long n0, long n) { long ytiles = RM > 1 ? (m - m0) / RM : 1; long xtiles = RN > 1 ? (n - n0) / RN : 1; long tiles = xtiles * ytiles; long duty = (tiles + nth - 1) / nth; long start = duty * ith; long end = start + duty; if (end > tiles) end = tiles; for (long job = start; job < end; ++job) { long ii = m0 + job / xtiles * RM; long jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; D Ce[RN][RM] = {}; for (long l = 0; l < k; l += KN) #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) if (PRECISE) Cv[j][i] = madder(load(INDEX(A, lda, ii + i, l)), // load(INDEX(B, ldb, jj + j, l)), // Cv[j][i], &Ce[j][i]); else Cv[j][i] = madd(load(INDEX(A, lda, ii + i, l)), // load(INDEX(B, ldb, jj + j, l)), // Cv[j][i]); #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i])); } } const TA* const A; const TB* const B; TC* const C; const long k; const long lda; const long ldb; const long ldc; const int ith; const int nth; }; ////////////////////////////////////////////////////////////////////////////////////////// // QUANT ZERO MATRIX MULTIPLICATION #if defined(__ARM_FEATURE_DOTPROD) template class tinyBLAS_Q0_ARM { public: tinyBLAS_Q0_ARM(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(long m, long n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: NOINLINE void mnpack(long m0, long m, long n0, long n) { long mc, nc, mp, np; if (!FLAG_precise) { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) { case 0x33: mc = 3; nc = 3; gemm<3, 3, false>(m0, m, n0, n); break; case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, false>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, false>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, false>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) { case 0x33: mc = 3; nc = 3; gemm<3, 3, true>(m0, m, n0, n); break; case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, true>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } template NOINLINE void gemm(long m0, long m, long n0, long n) { long ytiles = RM > 1 ? (m - m0) / RM : 1; long xtiles = RN > 1 ? (n - n0) / RN : 1; long tiles = xtiles * ytiles; long duty = (tiles + nth - 1) / nth; long start = duty * ith; long end = start + duty; if (end > tiles) end = tiles; for (long job = start; job < end; ++job) { long ii = m0 + job / xtiles * RM; long jj = n0 + job % xtiles * RN; float32x4_t Cv[RN][RM] = {}; float32x4_t Ce[RN][RM] = {}; for (int l = 0; l < k; ++l) #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) { float32x4_t a = vcvtq_f32_s32(vdotq_s32( vdotq_s32(vdupq_n_s32(0), load_lo(INDEX(A, lda, ii + i, l)), load_lo(INDEX(B, ldb, jj + j, l))), load_hi(INDEX(A, lda, ii + i, l)), load_hi(INDEX(B, ldb, jj + j, l)))); float b = unhalf(INDEX(A, lda, ii + i, l)->d) * unhalf(INDEX(B, ldb, jj + j, l)->d); if (PRECISE) Cv[j][i] = badder(a, b, Cv[j][i], &Ce[j][i]); else Cv[j][i] = vmlaq_n_f32(Cv[j][i], a, b); } #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i])); } } inline int8x16_t load_lo(const block_q8_0* b) { return vld1q_s8(b->qs); } inline int8x16_t load_hi(const block_q8_0* b) { return vld1q_s8(b->qs + 16); } inline int8x16_t load_lo(const block_q4_0* b) { return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))), vdupq_n_s8(0x8)); } inline int8x16_t load_hi(const block_q4_0* b) { return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8)); } const TA* const A; const TB* const B; TC* const C; const long k; const long lda; const long ldb; const long ldc; const int ith; const int nth; }; #endif // __ARM_FEATURE_DOTPROD #if defined(__AVX2__) || defined(__AVX512F__) template class tinyBLAS_Q0_AVX2 { public: tinyBLAS_Q0_AVX2(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(long m, long n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: void mnpack(long m0, long m, long n0, long n) { long mc, nc, mp, np; #if VECTOR_REGISTERS == 32 if (!FLAG_precise) { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) { case 0x33: mc = 3; nc = 3; gemm<3, 3, false>(m0, m, n0, n); break; case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) { case 0x33: mc = 3; nc = 3; gemm<3, 3, true>(m0, m, n0, n); break; case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, true>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } #endif #if VECTOR_REGISTERS == 16 if (!FLAG_precise) { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) { case 0x32: mc = 3; nc = 2; gemm<3, 2, false>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; gemm<2, 3, false>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, false>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; gemm<1, 2, false>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, false>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 2) << 4) | MIN(n - n0, 1)) { case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } #endif mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } template NOINLINE void gemm(long m0, long m, long n0, long n) { long ytiles = RM > 1 ? (m - m0) / RM : 1; long xtiles = RN > 1 ? (n - n0) / RN : 1; long tiles = xtiles * ytiles; long duty = (tiles + nth - 1) / nth; long start = duty * ith; long end = start + duty; if (end > tiles) end = tiles; for (long job = start; job < end; ++job) { long ii = m0 + job / xtiles * RM; long jj = n0 + job % xtiles * RN; __m256 Cv[RN][RM] = {}; __m256 Ce[RN][RM] = {}; for (long l = 0; l < k; ++l) #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) { __m256 a = _mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) * unhalf(INDEX(B, ldb, jj + j, l)->d)); __m256 b = updot(_mm256_sign_epi8(load(INDEX(A, lda, ii + i, l)), load(INDEX(A, lda, ii + i, l))), _mm256_sign_epi8(load(INDEX(B, ldb, jj + j, l)), load(INDEX(A, lda, ii + i, l)))); if (PRECISE) Cv[j][i] = madder(a, b, Cv[j][i], &Ce[j][i]); else Cv[j][i] = madd(a, b, Cv[j][i]); } #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i])); } } inline __m256i load(const block_q8_0* b) { return _mm256_loadu_si256((const __m256i*)b->qs); } inline __m256i load(const block_q4_0* b) { __m128i x = _mm_loadu_si128((const __m128i*)b->qs); return _mm256_sub_epi8(_mm256_and_si256(_mm256_set1_epi8(15), _mm256_insertf128_si256(_mm256_castsi128_si256(x), _mm_srli_epi16(x, 4), 1)), _mm256_set1_epi8(8)); } inline __m256 updot(__m256i u, __m256i s) { __m256i res; #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); #else res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #endif return _mm256_cvtepi32_ps(res); } const TA* const A; const TB* const B; TC* const C; const long k; const long lda; const long ldb; const long ldc; const int ith; const int nth; }; #endif // __AVX2__ } // namespace