From 808aba39161e5d7ca2ff24110b5aa14d2e536988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 11 Jul 2024 16:47:47 +0200 Subject: [PATCH 01/13] CUDA: optimize and refactor MMQ (#8416) * CUDA: optimize and refactor MMQ * explicit q8_1 memory layouts, add documentation --- ggml/src/ggml-cuda/mma.cuh | 4 + ggml/src/ggml-cuda/mmq.cuh | 1365 ++++++++++++++++--------------- ggml/src/ggml-cuda/quantize.cu | 107 ++- ggml/src/ggml-cuda/quantize.cuh | 6 +- ggml/src/ggml-cuda/vecdotq.cuh | 72 +- 5 files changed, 867 insertions(+), 687 deletions(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 5d87dd8e6..a452a3cc3 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -70,6 +70,10 @@ struct mma_int_A_I16K8 { } #endif // defined(INT8_MMA_AVAILABLE) } + + __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) { + ((mma_int_A_I16K4 *) x)[0].load(xs0, stride); + } }; struct mma_int_B_J8K4 { diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 118e34d28..51c44d857 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -8,18 +8,70 @@ #include #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. 
+#define MMQ_ITER_K 256 +#define MMQ_NWARPS 8 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride); -typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00); typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max); +enum mmq_q8_1_ds_layout { + MMQ_Q8_1_DS_LAYOUT_D4, + MMQ_Q8_1_DS_LAYOUT_DS4, + MMQ_Q8_1_DS_LAYOUT_D2S6, +}; + struct block_q8_1_mmq { - half2 ds[4]; - int8_t qs[4*QK8_1]; + // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block. + // The y float data is first grouped as blocks of 128 values. + // These blocks are then treated as individual data values and transposed. + // + // To avoid shared memory bank conflicts each block is padded with 16 bytes. + // This padding is also used to store block scales/partial sums. + // The scales multiplied with the quantized data are equal to the unquantized values. + // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization) + // and are only needed for performance reasons. + // + // The exact data stored depends on the x data type. 
+ union { + float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3 + half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3 + half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values, + // stored as d0,d1,s1,s2,s3,s4,s5 + }; + int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each }; static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); +static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q5_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q5_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q8_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q2_K: + return MMQ_Q8_1_DS_LAYOUT_D2S6; + case GGML_TYPE_Q3_K: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return MMQ_Q8_1_DS_LAYOUT_D4; + default: + GGML_ASSERT(false); + break; + } +} + struct tile_x_sizes { int qs; int dm; @@ -79,49 +131,46 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) } -#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} -#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} -#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0} -#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0} -#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0} 
-#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} -#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4} -#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : - type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : - type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : + type == GGML_TYPE_Q5_0 ? 
MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 : type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : - type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q5_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q5_0 : + type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : tile_x_sizes{0, 0, 0}; } -#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) -#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4) -#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0 + 4) -#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1 + 4) -#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0 + 0) -#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2) -#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7) -#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7) -#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) +#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong 
padding."); static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding."); @@ -131,21 +180,20 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : - type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : - type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : + type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : - type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q5_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q5_0 : + type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 : 0; } #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) -#define MMQ_NWARPS 8 static int mmq_get_granularity_host(const int mmq_x, const int cc) { return int8_mma_available(cc) && mmq_x >= 48 ? 
16 : 8; @@ -218,7 +266,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); const int * x_qs = (const int *) x; @@ -226,34 +274,39 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - - int u[2*VDR_Q4_0_Q8_1_MMQ]; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE]; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl + (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u, + x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0], u, 
x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], - y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -271,52 +324,60 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A[ntx]; - float dA[ntx][mma_C::ne/2]; + mma_A A[ntx][4]; + float dA[ntx][mma_C::ne/2][4]; const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + n*mma_A::I + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l) % QI4_0; - const int shift = 4*(mma_A::get_k(l) / QI4_0); - - A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808); - } + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) { + const int k0 = k00 + k01; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + n*mma_A::I + mma_A::get_i(l); + const int k = k0/QR4_0 + mma_A::get_k(l) % QI4_0; + const int shift = 4*(mma_A::get_k(l) / QI4_0); - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0]; + A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808); + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)]; + } } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - float dB[mma_C::ne/2]; +#pragma unroll + 
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) { + mma_B B; + float dB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); - dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); - } + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l]; + } } } } @@ -381,7 +442,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); const int * x_qs = (const int *) x; @@ -389,34 +450,39 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - - int u[2*VDR_Q4_1_Q8_1_MMQ]; + for 
(int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE]; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl + (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u, + x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], - y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -435,50 +501,58 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A[ntx]; - half2 dmA[ntx][mma_C::ne/2]; + mma_A A[ntx][4]; + half2 dmA[ntx][mma_C::ne/2][4]; const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll for (int n = 0; n < ntx; ++n) { - ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1); - A[n].x[2] = (A[n].x[0] >> 4) & 0x0F0F0F0F; - A[n].x[3] = (A[n].x[1] >> 4) & 0x0F0F0F0F; - A[n].x[0] &= 0x0F0F0F0F; - 
A[n].x[1] &= 0x0F0F0F0F; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) { + const int k0 = k00 + k01; + + A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1); + A[n][k01/(QR4_1*QI4_1)].x[2] = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F; + A[n][k01/(QR4_1*QI4_1)].x[3] = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F; + A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F; + A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1]; + dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)]; + } } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - half2 dsB[mma_C::ne/2]; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) { + mma_B B; + half2 dsB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); - dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; - } + dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]; + } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] 
+= __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } } } } @@ -531,8 +605,8 @@ template static __device__ __forceinlin qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; @@ -553,106 +627,13 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = bxi->d; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; #endif // INT8_MMA_AVAILABLE } } -template -static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE], - x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); - } - } -} - -template -static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ 
y, float * __restrict__ sum, const int & k0) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - mma_A A[ntx]; - float dA[ntx][mma_C::ne/2]; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I; - - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - float dB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l]; - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -694,8 +675,8 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 
19 -> 28 #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; @@ -716,113 +697,19 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + kbxd] = bxi->dm; + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; #endif // INT8_MMA_AVAILABLE } } -template -static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE], - x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); - } - } -} - -template -static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef 
mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - mma_A A[ntx]; - half2 dmA[ntx][mma_C::ne/2]; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I; - - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_tile + WARP_SIZE); + float * x_df = (float *) (x_tile + 2*WARP_SIZE); #else constexpr 
tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; @@ -843,18 +730,20 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx); + x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); #endif // INT8_MMA_AVAILABLE } - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { - int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) { + int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -863,16 +752,16 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; + x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; #endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes 
txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); const int * x_qs = (const int *) x; @@ -880,24 +769,29 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], - y_df[j*MMQ_TILE_Y_K + k0/QI8_1]); +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl + (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE], + x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } } } } template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -911,49 +805,178 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE; + const float * x_df = (const float *) x_qs + 2*WARP_SIZE; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - mma_A A[ntx]; - float dA[ntx][mma_C::ne/2]; + mma_A 
A[ntx][WARP_SIZE/QI8_0]; + float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll for (int n = 0; n < ntx; ++n) { - A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; + + A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; + + dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + } } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - float dB[mma_C::ne/2]; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; - B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K); + mma_B B; + float dB[mma_C::ne/2]; + + B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1]; + dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/QI8_0], B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; + } + } + } + } +#else + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + constexpr tile_x_sizes 
txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl + (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { +#ifdef INT8_MMA_AVAILABLE + + typedef mma_int_A_I16K8 mma_A; + typedef mma_int_B_J8K8 mma_B; + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. 
+ + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + mma_A A[ntx][WARP_SIZE/QI8_1]; + half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; + + A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2]; + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; + + dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; + + mma_B B; + half2 dsB[mma_C::ne/2]; + + B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/QI8_1], B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } } } } @@ -968,44 +991,37 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 
WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE - const int kbx = threadIdx.x / QI2_K; const int kqsx = threadIdx.x % QI2_K; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) { + int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K; if (need_check) { i = min(i, i_max); } - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx; + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); #pragma unroll for (int l = 0; l < QR2_K; ++l) { - const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4; + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; - int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4)); - x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); - x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE); - - if (kqsx % QR2_K != 0) { - continue; - } + const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; #ifdef INT8_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else - x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k; + x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; #endif // INT8_MMA_AVAILABLE } @@ -1018,44 +1034,68 @@ template static __device__ __forceinlin #endif // FAST_FP16_AVAILABLE #ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik; + x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else - x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik; + x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik; #endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ 
y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + float2 y_df[mmq_x/nwarps]; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], - &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]); +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + if (k01 < WARP_SIZE/2) { + constexpr int ns = 2; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } else { + constexpr int ns = 1; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? 
y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } } } } template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; + typedef mma_int_A_I16K8 mma_A_K8; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; @@ -1066,74 +1106,107 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][2]; - float dA[ntx][mma_C::ne/2][2]; - float mA[ntx][mma_C::ne/2][2]; + mma_A A[ntx][8]; + float dA[ntx][mma_C::ne/2][8]; + float mA[ntx][mma_C::ne/2][8]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + n*mma_A::I + mma_A::get_i(l); - const int shift = 2*mma_A::get_k(l); + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; - A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303; - A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303; + ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } + } +#pragma unroll + for (int n = 0; n < ntx; ++n) { #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); #pragma unroll - for (int kdm = 0; kdm < 2; ++kdm) { - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 
+ kdm]); + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) { + const int k0 = k00 + k01; - dA[n][l][kdm] = dm.x; - mA[n][l][kdm] = dm.y; + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); + + dA[n][l][k01/(QI8_1/2)] = dm.x; + mA[n][l][k01/(QI8_1/2)] = dm.y; } } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B[2]; - float dB[mma_C::ne/2]; - - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K); + float2 dB[mma_C::ne/2]; #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; + dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); } - mma_C Cm[2]; - mma_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - Cm[0].mma_K4(A1, B[0]); - Cm[1].mma_K4(A1, B[1]); +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + mma_B B[2]; + + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + + mma_C Cm[2]; + if (k01 >= WARP_SIZE * 3/4) { + mma_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + Cm[0].mma_K4(A1, B[0]); + Cm[1].mma_K4(A1, B[1]); + } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C Cd[2]; + for (int n = 0; n < ntx; ++n) { + mma_C Cd[2]; - Cd[0].mma_K4(A[n][0], B[0]); - Cd[1].mma_K4(A[n][1], B[1]); + Cd[0].mma_K4(A[n][k01/4 + 0], B[0]); + Cd[1].mma_K4(A[n][k01/4 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += ( - Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; + if (k01 >= WARP_SIZE * 3/4) { + tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + 
Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; + } + sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y); + } + } + } + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) { + float2 sB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; + sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; + } } } } @@ -1149,7 +1222,7 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); - int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K); + int * x_sc = (int *) (x_df + 1); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); int * x_qs = (int *) x_tile; @@ -1157,75 +1230,66 @@ template static __device__ __forceinlin int * x_sc = (int *) (x_df + txs.dm); #endif // INT8_MMA_AVAILABLE - const int kbx = threadIdx.x / QI3_K; const int kqsx = threadIdx.x % QI3_K; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) { + int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K; if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx; + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); #pragma unroll for (int l = 0; l < QR3_K; ++l) { - const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8; + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; const int x_ql_k = (x_ql_0 >> (2*l)) & 
0x03030303; const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404; - int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2)); - x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); - - if (kqsx % 2 != 0) { - continue; - } + const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else - x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k; + x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; #endif // INT8_MMA_AVAILABLE } } - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { - int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { + int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd; + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + kbxd] = bxi->d; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d; #else - x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d; + x_df[i] = bxi->d; #endif // INT8_MMA_AVAILABLE } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { + int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8); if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4); + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - const int ksc = threadIdx.x % (QI3_K/4); + const int ksc = threadIdx.x % (WARP_SIZE/8); const int ksc_low = ksc % (QI3_K/8); const int shift_low = 4 * (ksc / (QI3_K/8)); @@ -1238,16 +1302,16 @@ template static 
__device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); #ifdef INT8_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc; + x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc; #else - x_sc[i*(WARP_SIZE/4) + i/4 + threadIdx.x % (WARP_SIZE/4)] = sc; + x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; #endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); const int * x_qs = (const int *) x; @@ -1256,32 +1320,35 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - const int kbx = k0 / QI3_K; - const int ky = (k0 % QI3_K) * QR3_K; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales, - x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]); + 
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; + typedef mma_int_A_I16K8 mma_A_K8; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; @@ -1293,73 +1360,74 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + WARP_SIZE*2; - const int * x_sc = (const int *) x_df + WARP_SIZE/QI3_K; + const int * x_sc = (const int *) x_df + 1; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][2]; - int scA[ntx][mma_C::ne/2][2]; + mma_A A[ntx][8]; + int scA[ntx][mma_C::ne/2][8]; float dA[ntx][mma_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + n*mma_A::I + mma_A::get_i(l); - const int k = QR3_K*k0 + mma_A::get_k(l); + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; - A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F; - A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F; - A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404); - A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404); + ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const int kbx = k0 / QI3_K; - 
const int ky = (k0 % QI3_K) * QR3_K; - const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + const int k0 = k00 + k01; - scA[n][l][0] = sc[0]; - scA[n][l][1] = sc[1]; - } + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16]; + const int8_t * sc = (const int8_t *) &sc_packed; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][k01/4 + ksc] = sc[ksc]; + } + } - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K]; } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B[2]; - float dB[mma_C::ne/2]; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + mma_B B[2]; + float dB[mma_C::ne/2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K); + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; - } + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C[2]; - C[0].mma_K4(A[n][0], B[0]); - C[1].mma_K4(A[n][1], B[1]); + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][k01/4 + 0], B[0]); + C[1].mma_K4(A[n][k01/4 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + 
n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]* + (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]); + } } } } @@ -1451,7 +1519,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); const int * x_qs = (const int *) x; @@ -1460,26 +1528,31 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8); +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( - &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8, - x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]); + const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16); + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void 
vec_dot_q4_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -1500,35 +1573,40 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][2]; - int scA[ntx][mma_C::ne/2][2]; - int mA[ntx][mma_C::ne/2][2]; + mma_A A[ntx][4]; + int scA[ntx][mma_C::ne/2][4]; + int mA[ntx][mma_C::ne/2][4]; half2 dmA[ntx][mma_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) { - A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K); + for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + const int k0 = k00 + k01; + + A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K); #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { - A[n][kvdr/4 + 1].x[l] = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F; - A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F; + A[n][k01/8 + 1].x[l] = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F; + A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F; } } #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)]; + const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)]; + + const uint8_t * sc = (const uint8_t *) &sc_packed; + const uint8_t * m = (const uint8_t *) &m_packed; + #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - - const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8); - const uint8_t * m = sc + 8; - - scA[n][l][kvdr/4] = sc[kvdr/4]; - 
mA[n][l][kvdr/4] = m[kvdr/4]; + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][ksc] = sc[ksc]; + mA[n][l][ksc] = m[ksc]; } } @@ -1536,7 +1614,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K]; } } @@ -1546,28 +1624,28 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( float tmpm[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { mma_B B; half2 dsB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; + dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]; } #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][kvdr/4], B); + C.mma_K8(A[n][k01/8], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]); - tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]); + tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]); + tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]); } } } @@ -1682,7 +1760,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); const int * x_qs = (const int *) x; @@ -1691,26 +1769,31 @@ static __device__ __forceinline__ void 
vec_dot_q5_K_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( - &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8, - x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16); + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -1731,26 +1814,34 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][2]; - int scA[ntx][mma_C::ne/2][2]; - int mA[ntx][mma_C::ne/2][2]; + mma_A A[ntx][4]; + int scA[ntx][mma_C::ne/2][4]; + int mA[ntx][mma_C::ne/2][4]; half2 dmA[ntx][mma_C::ne/2]; #pragma 
unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { - A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K); + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + + A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K); + } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8); - const uint8_t * m = sc + 8; + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)]; + const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)]; - scA[n][l][kvdr/4] = sc[kvdr/4]; - mA[n][l][kvdr/4] = m[kvdr/4]; + const uint8_t * sc = (const uint8_t *) &sc_packed; + const uint8_t * m = (const uint8_t *) &m_packed; + +#pragma unroll + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][ksc] = sc[ksc]; + mA[n][l][ksc] = m[ksc]; } } @@ -1758,7 +1849,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K]; } } @@ -1768,28 +1859,30 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( float tmpm[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + mma_B B; half2 dsB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + 
mma_C::get_j(l); - dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; + dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; } #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][kvdr/4], B); + C.mma_K8(A[n][k01/8], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]); - tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]); + tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]); + tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]); } } } @@ -1896,7 +1989,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); const int * x_qs = (const int *) x; @@ -1905,26 +1998,31 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( - &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc, - x_df[i*(WARP_SIZE/QI6_K) + 
i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]); + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]); + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; @@ -1945,25 +2043,35 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][4]; - int scA[ntx][mma_C::ne/2][4]; + mma_A A[ntx][8]; + int scA[ntx][mma_C::ne/2][8]; float dA[ntx][mma_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { - A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0), MMQ_MMA_TILE_X_K_Q6_K); - A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + + A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + const int k0 = k00 + k01; #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]); + const int sc_packed = 
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; + const int8_t * sc = (const int8_t *) &sc_packed; - scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0]; - scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1]; +#pragma unroll + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][k01/4 + ksc] = sc[ksc]; + } } } @@ -1971,7 +2079,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; } } @@ -1980,30 +2088,29 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( float tmp[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { mma_B B[2]; float dB[mma_C::ne/2]; - const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k0B, MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K); + B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; } #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C[2]; - C[0].mma_K4(A[n][kvdr/2 + 0], B[0]); - C[1].mma_K4(A[n][kvdr/2 + 1], B[1]); + C[0].mma_K4(A[n][k01/4 + 0], B[0]); + C[1].mma_K4(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2]; + tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; } } } @@ -2051,8 +2158,8 @@ template static __device__ __forceinlin const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; #ifdef 
INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; @@ -2073,7 +2180,7 @@ template static __device__ __forceinlin const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d); + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d); #endif // INT8_MMA_AVAILABLE @@ -2109,8 +2216,8 @@ template static __device__ __forceinlin const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; @@ -2133,7 +2240,7 @@ template static __device__ __forceinlin | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32); + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); #endif // INT8_MMA_AVAILABLE @@ -2229,16 +2336,16 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; 
template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; template @@ -2293,45 +2400,18 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; template struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -static bool mmq_need_sum(const ggml_type type_x) { - switch (type_x) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - return true; - case GGML_TYPE_Q5_0: - return false; - case GGML_TYPE_Q5_1: - return true; - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - return false; - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - return true; - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - return false; - default: - GGML_ASSERT(false); - break; - } - return false; -} - template static __device__ void mul_mat_q_process_tile( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * 
__restrict__ tmp_fixup, @@ -2339,10 +2419,7 @@ static __device__ void mul_mat_q_process_tile( const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) { constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qr = ggml_cuda_type_traits::qr; - constexpr int qi = ggml_cuda_type_traits::qi; constexpr int mmq_y = get_mmq_y_device(); - constexpr int vdr = mmq_type_traits::vdr; constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; extern __shared__ char data_mul_mat_q[]; @@ -2357,7 +2434,7 @@ static __device__ void mul_mat_q_process_tile( constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; #endif // INT8_MMA_AVAILABLE - constexpr int blocks_per_warp = WARP_SIZE / qi; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; @@ -2366,29 +2443,40 @@ static __device__ void mul_mat_q_process_tile( const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); - for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) { - + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); -#pragma unroll - for (int kr = 0; kr < qr; ++kr) { - const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int)); + { + const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; tile_y[l] = by0[l]; } - - __syncthreads(); - -// #pragma unroll // unrolling this loop causes too much register pressure - for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { - vec_dot(tile_x, tile_y, sum, k0); - } - - __syncthreads(); } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, 0); + + __syncthreads(); + + { + const int * 
by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { + int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, WARP_SIZE); + + __syncthreads(); } if (fixup) { @@ -2424,7 +2512,6 @@ static __global__ void mul_mat_q( } constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; constexpr int mmq_y = get_mmq_y_device(); // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: @@ -2439,7 +2526,7 @@ static __global__ void mul_mat_q( #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA const int64_t blocks_per_ne00 = ne00 / qk; - constexpr int blocks_per_warp = WARP_SIZE / qi; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y @@ -2448,8 +2535,8 @@ static __global__ void mul_mat_q( int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x; int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_warp; - kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp; + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; // kb0 == k index when doing the matrix multiplication for an output tile. 
int kb0_start = kbc % blocks_per_ne00; @@ -2490,8 +2577,7 @@ static __global__ void mul_mat_q_stream_k_fixup( constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int blocks_per_warp = WARP_SIZE / qi; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; const int64_t blocks_per_ne00 = ne00 / qk; float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; @@ -2501,15 +2587,18 @@ static __global__ void mul_mat_q_stream_k_fixup( bool any_fixup = false; - const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x); - const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1; + const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x); + const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x); + + int64_t kbc_0; + int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq; for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) { - int64_t kbc = (int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq; - int64_t kbc_stop = (int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq; + kbc_0 = kbc_stop_0; + kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq; - kbc -= (kbc % blocks_per_ne00) % blocks_per_warp; - kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp; + const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter; + const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter; // Skip fixup tile if the MMQ CUDA block never wrote anything to it: if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) { diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index b46786822..aa7f1eff0 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -37,47 +37,92 @@ static 
__global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -template +template static __global__ void quantize_mmq_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { - const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; + constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; + + const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; if (ix0 >= kx0_padded) { return; } + const float4 * x4 = (const float4 *) x; + const int64_t ix1 = kx1*blockIdx.z + blockIdx.y; block_q8_1_mmq * y = (block_q8_1_mmq *) vy; - const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel - const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel - const int64_t iqs = ix0 % (4*QK8_1); // quant index in block + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel + const int64_t iqs = ix0 % (4*QK8_1); // quant index in block - const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f; - float amax = fabsf(xi); + // Load 4 floats per thread and calculate max. abs. value between them: + const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float amax = fabsf(xi.x); + amax = fmaxf(amax, fabsf(xi.y)); + amax = fmaxf(amax, fabsf(xi.z)); + amax = fmaxf(amax, fabsf(xi.w)); - amax = warp_reduce_max(amax); - - float sum; - if (need_sum) { - sum = warp_reduce_sum(xi); + // Exchange max. abs. value between vals_per_scale/4 threads. 
+#pragma unroll + for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); } - const float d = amax / 127; - const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + float sum; + if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { + sum = xi.x + xi.y + xi.z + xi.w; - y[ib].qs[iqs] = q; + // Exchange calculate sum across vals_per_sum/4 threads. +#pragma unroll + for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) { + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE); + } + } + + const float d_inv = 127.0f / amax; + char4 q; + q.x = roundf(xi.x*d_inv); + q.y = roundf(xi.y*d_inv); + q.z = roundf(xi.z*d_inv); + q.w = roundf(xi.w*d_inv); + + // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + char4 * yqs4 = (char4 *) y[ib].qs; + yqs4[iqs/4] = q; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) { + if (iqs % 16 != 0 || iqs >= 96) { + return; + } + + y[ib].d2s6[2 + iqs/16] = sum; + + if (iqs % 64 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + y[ib].d2s6[iqs/64] = d; - if (iqs % QK8_1 != 0) { return; } - if (need_sum) { - y[ib].ds[iqs/QK8_1] = make_half2(d, sum); + if (iqs % 32 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) { + y[ib].ds4[iqs/32] = make_half2(d, sum); } else { - ((float *) y[ib].ds)[iqs/QK8_1] = d; + y[ib].d4[iqs/32] = d; } } @@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda( GGML_ASSERT(kx0_padded % (4*QK8_1) == 0); - const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); const dim3 num_blocks(block_num_x, kx1, channels); - const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - if (mmq_need_sum(type_x)) { - quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); - } else { - quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); + const 
dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); + switch (mmq_get_q8_1_ds_layout(type_x)) { + case MMQ_Q8_1_DS_LAYOUT_D4: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); + break; + case MMQ_Q8_1_DS_LAYOUT_DS4: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); + break; + case MMQ_Q8_1_DS_LAYOUT_D2S6: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); + break; + default: + GGML_ASSERT(false); + break; } } diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 486c9360a..03bf322b9 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -5,7 +5,11 @@ #include -#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128 + +static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access."); +static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access."); typedef void (*quantize_cuda_t)( const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 1d510484a..6a17d0f3e 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -189,7 +189,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp } #define VDR_Q2_K_Q8_1_MMVQ 1 -#define VDR_Q2_K_Q8_1_MMQ 2 +#define VDR_Q2_K_Q8_1_MMQ 4 // contiguous v/x values static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( @@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( return dm2f.x*sumf_d - dm2f.y*sumf_m; } -// contiguous u/y values +// contiguous v/x + u/y values +template static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) { + const int * __restrict__ v, const int * 
__restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) { - float sumf_d = 0.0f; - float sumf_m = 0.0f; + float sumf = 0.0f; + float sumf_d8 = 0.0f; #pragma unroll - for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { - const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]); - int sumi_d = 0; - int sumi_m = 0; + for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) { + const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]); + int sumi_d0 = 0; + + const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]); + int sumi_d1 = 0; - const int vi0 = v[i0/(QI8_1/2)]; #pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303; - sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product - sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m); + sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0); } + sumf_d8 += dm2f0.x * sumi_d0; - sumf_d += dm2f.x * sumi_d; - sumf_m += dm2f.y * sumi_m; +#pragma unroll + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) { + sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1); + } + sumf_d8 += dm2f1.x * sumi_d1; + + if (i0/QI8_1 < ns8) { + const float2 s8f = __half22float2(s8[i0/QI8_1]); + sumf -= dm2f0.y*s8f.x; + sumf -= dm2f1.y*s8f.y; + } else { + int sumi_m0 = 0; +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0); + } + sumf_d8 -= dm2f0.y * sumi_m0; + + int sumi_m1 = 0; +#pragma unroll + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) { + sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1); + } + sumf_d8 -= dm2f1.y * sumi_m1; + } } - return d8*(sumf_d - sumf_m); + return sumf + d8*sumf_d8; } #define VDR_Q3_K_Q8_1_MMVQ 1 @@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( return d3 * sumf; } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t 
* __restrict__ scales, const float & d3, const float & d8) { @@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( #pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404); - sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product + sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product } sumi += sumi_sc * scales[i0 / (QI8_1/2)]; @@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( return dm4f.x*sumf_d - dm4f.y*sumf_m; } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { @@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( return dm5f.x*sumf_d - dm5f.y*sumf_m; } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { @@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( return d*sumf; } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, const float & d6, const float * __restrict__ d8) { float sumf_d = 0.0f; + const int sc_packed = get_int_b4(sc, 0); + const int8_t * sc_reg = (const int8_t *) &sc_packed; + #pragma unroll for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale @@ -471,7 +497,7 @@ static __device__ __forceinline__ float 
vec_dot_q6_K_q8_1_impl_mmq( sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product } - sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); + sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y); } return d6 * sumf_d; From b078c619aa4e97fe726d61b0b5499a2e19a418a4 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 11 Jul 2024 17:53:42 +0200 Subject: [PATCH 02/13] cuda : suppress 'noreturn' warn in no_device_code (#8414) * cuda : suppress 'noreturn' warn in no_device_code This commit adds a while(true) loop to the no_device_code function in common.cuh. This is done to suppress the warning: ```console /ggml/src/ggml-cuda/template-instances/../common.cuh:346:1: warning: function declared 'noreturn' should not return [-Winvalid-noreturn] 346 | } | ^ ``` The motivation for this is to reduce the number of warnings when compiling with GGML_HIPBLAS=ON. Signed-off-by: Daniel Bevenius * squash! cuda : suppress 'noreturn' warn in no_device_code Update __trap macro instead of using a while loop to suppress the warning.
Signed-off-by: Daniel Bevenius --------- Signed-off-by: Daniel Bevenius --- ggml/src/ggml-cuda/common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 4ff06b871..26d9412a2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -104,7 +104,7 @@ #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess -#define __trap abort +#define __trap() do { abort(); __builtin_unreachable(); } while(0) #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED From 368645698ab648e390dcd7c00a2bf60efa654f57 Mon Sep 17 00:00:00 2001 From: Nicholai Tukanov Date: Thu, 11 Jul 2024 11:49:15 -0500 Subject: [PATCH 03/13] ggml : add NVPL BLAS support (#8329) (#8425) * ggml : add NVPL BLAS support * ggml : replace `_ENABLE_CBLAS` with `GGML_BLAS_USE_` --------- Co-authored-by: ntukanov --- Makefile | 8 +++++++- ggml/src/ggml-blas.cpp | 13 +++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 668b38b99..4869d2ecf 100644 --- a/Makefile +++ b/Makefile @@ -547,11 +547,17 @@ ifdef GGML_OPENBLAS64 endif # GGML_OPENBLAS64 ifdef GGML_BLIS - MK_CPPFLAGS += -DGGML_USE_BLAS -I/usr/local/include/blis -I/usr/include/blis + MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis MK_LDFLAGS += -lblis -L/usr/local/lib OBJ_GGML += ggml/src/ggml-blas.o endif # GGML_BLIS +ifdef GGML_NVPL + MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas + MK_LDFLAGS += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp + OBJ_GGML += ggml/src/ggml-blas.o +endif # GGML_NVPL + ifndef GGML_NO_LLAMAFILE MK_CPPFLAGS += 
-DGGML_USE_LLAMAFILE OBJ_GGML += ggml/src/llamafile/sgemm.o diff --git a/ggml/src/ggml-blas.cpp b/ggml/src/ggml-blas.cpp index d709a357b..a37aa4072 100644 --- a/ggml/src/ggml-blas.cpp +++ b/ggml/src/ggml-blas.cpp @@ -8,11 +8,12 @@ # include #elif defined(GGML_BLAS_USE_MKL) # include +#elif defined(GGML_BLAS_USE_BLIS) +# include +#elif defined(GGML_BLAS_USE_NVPL) +# include #else # include -# ifdef BLIS_ENABLE_CBLAS -# include -# endif #endif struct ggml_backend_blas_context { @@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg openblas_set_num_threads(ctx->n_threads); #endif -#if defined(BLIS_ENABLE_CBLAS) +#if defined(GGML_BLAS_USE_BLIS) bli_thread_set_num_threads(ctx->n_threads); #endif +#if defined(GGML_BLAS_USE_NVPL) + nvpl_blas_set_num_threads(ctx->n_threads); +#endif + for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { const int64_t i03 = i13/r3; From b549a1bbefb2f1fbb8b558bac1f2ae7967e60964 Mon Sep 17 00:00:00 2001 From: Chen Xi Date: Fri, 12 Jul 2024 00:52:04 +0000 Subject: [PATCH 04/13] [SYCL] fix the mul_mat_id ut issues (#8427) * fix part of mul_mat_id * skip the bfloat 16 sycl ut Signed-off-by: Chen Xi --------- Signed-off-by: Chen Xi Co-authored-by: Meng, Hengyu Co-authored-by: Chen Xi --- ggml/src/ggml-backend.c | 2 +- ggml/src/ggml-sycl.cpp | 49 +++++++++++------------------------------ src/llama.cpp | 7 ------ 3 files changed, 14 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index 13c71c310..dbbaa3941 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -394,7 +394,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) // backend registry -#define GGML_REG_MAX_BACKENDS 16 +#define GGML_REG_MAX_BACKENDS 64 struct ggml_backend_reg { char name[128]; diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 9c419ba89..5a890237f 100644 --- a/ggml/src/ggml-sycl.cpp 
+++ b/ggml/src/ggml-sycl.cpp @@ -3768,37 +3768,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); - const ggml_tensor_extra_gpu *src0_extra = - (const ggml_tensor_extra_gpu *)src0->extra; - const ggml_tensor_extra_gpu *src1_extra = - (const ggml_tensor_extra_gpu *)src1->extra; - const ggml_tensor_extra_gpu *dst_extra = - (const ggml_tensor_extra_gpu *)dst->extra; - - ggml_tensor_extra_gpu src0_row_extra; - ggml_tensor_extra_gpu src1_row_extra; - ggml_tensor_extra_gpu dst_row_extra; - ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; ggml_tensor dst_row = *dst; - src1_row.backend = GGML_BACKEND_TYPE_GPU; - dst_row.backend = GGML_BACKEND_TYPE_GPU; - - src0_row.extra = &src0_row_extra; - src1_row.extra = &src1_row_extra; - dst_row.extra = &dst_row_extra; - - char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU - ? (char *)src0->data - : (char *)src0_extra->data_device[ctx.device]; - char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU - ? (char *)src1->data - : (char *)src1_extra->data_device[ctx.device]; - char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU - ? 
(char *)dst->data - : (char *)dst_extra->data_device[ctx.device]; + char *src0_original = (char *)src0->data; + char *src1_original = (char *)src1->data; + char *dst_original = (char *)dst->data; src0_row.ne[2] = 1; src0_row.ne[3] = 1; @@ -3827,12 +3803,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten const int64_t i1 = id; const int64_t i2 = i12; - src0_row_extra.data_device[ctx.device] = - src0_original + i02*nb02; - src1_row_extra.data_device[ctx.device] = - src1_original + + i11*nb11 + i12*nb12; - dst_row_extra.data_device[ctx.device] = - dst_original + i1*nb1 + i2*nb2; + src0_row.data = src0_original + i02*nb02; + src1_row.data = src1_original + + i11*nb11 + i12*nb12; + dst_row.data = dst_original + i1*nb1 + i2*nb2; ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); } @@ -3841,8 +3814,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - src1_row_extra.data_device[ctx.device] = src1_contiguous.get(); - dst_row_extra.data_device[ctx.device] = dst_contiguous.get(); + src1_row.data = src1_contiguous.get(); + dst_row.data = dst_contiguous.get(); for (int64_t i02 = 0; i02 < n_as; i02++) { int64_t num_src1_rows = 0; @@ -3898,7 +3871,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten }); } - src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02; + src0_row.data = src0_original + i02*nb02; GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb1 == sizeof(float)*ne0); @@ -5221,6 +5194,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons return false; } } + ggml_type src0_type = op->src[0]->type; + if (src0_type == GGML_TYPE_BF16) { + return false; + } return true; } break; case GGML_OP_GET_ROWS: diff --git a/src/llama.cpp b/src/llama.cpp index 
ed77ed918..f91ac7779 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5883,13 +5883,6 @@ static bool llm_load_tensors( auto & hparams = model.hparams; -#ifdef GGML_USE_SYCL - // disable MoE with SYCL until mul_mat_id is updated - if (hparams.n_expert > 0) { - n_gpu_layers = 0; - } -#endif - model.split_mode = split_mode; model.main_gpu = main_gpu; model.n_gpu_layers = n_gpu_layers; From 370b1f7e7a7514d9a63daf15f7b0f00319b7f908 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jul 2024 10:46:02 +0300 Subject: [PATCH 05/13] ggml : minor naming changes (#8433) * ggml : minor naming changes ggml-ci * ggml : use PRId64 [no ci] * ggml : revert FA K/Q names --- examples/quantize-stats/quantize-stats.cpp | 2 +- ggml/include/ggml.h | 50 +++++----- ggml/src/ggml-aarch64.c | 68 +++++++------ ggml/src/ggml-aarch64.h | 14 +-- ggml/src/ggml-quants.c | 94 ++++++++--------- ggml/src/ggml-quants.h | 34 +++---- ggml/src/ggml.c | 111 ++++++++++----------- tests/test-double-float.cpp | 4 +- tests/test-quantize-fns.cpp | 2 +- tests/test-quantize-perf.cpp | 2 +- 10 files changed, 192 insertions(+), 189 deletions(-) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 746df8446..68cf8d359 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -154,7 +154,7 @@ static void test_roundtrip_on_chunk( } if (use_reference) { - qfns.from_float_reference(input_scratch, quantized_scratch, chunk_size); + qfns.from_float_ref(input_scratch, quantized_scratch, chunk_size); } else { qfns.from_float(input_scratch, quantized_scratch, chunk_size); } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1e3677537..f2145ff35 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -714,9 +714,9 @@ extern "C" { GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor); GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() 
but padded to GGML_MEM_ALIGN - GGML_API GGML_CALL int ggml_blck_size(enum ggml_type type); - GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block - GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row + GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type); + GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block + GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row GGML_DEPRECATED( GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float @@ -2410,31 +2410,31 @@ extern "C" { #endif typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, - const void * GGML_RESTRICT y, size_t by, int nrc); - typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, - int64_t k, int64_t bx); - typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, - const void * GGML_RESTRICT y, int nr, int nc); - typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, - const void * GGML_RESTRICT y, int nr, int nc); + typedef void (*ggml_from_float_to_mat_t) + (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs); + typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, + const void * GGML_RESTRICT y, size_t by, int nrc); + typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); + typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); typedef struct { - const char * type_name; - int blck_size; - size_t type_size; - bool is_quantized; - ggml_to_float_t to_float; - ggml_from_float_t from_float; - ggml_from_float_t from_float_reference; - ggml_vec_dot_t vec_dot; - enum ggml_type vec_dot_type; - int64_t nrows; // number of rows to process simultaneously; - int64_t ncols; // number of columns to process simultaneously; - int64_t interleave_blcksize; // interleave elements in blocks of interleave_blcksize; + const char * type_name; + int64_t blck_size; + int64_t blck_size_interleave; // interleave elements in blocks + size_t type_size; + bool is_quantized; + ggml_to_float_t to_float; + ggml_from_float_t from_float; + ggml_from_float_t from_float_ref; ggml_from_float_to_mat_t from_float_to_mat; - ggml_gemv_t gemv; - ggml_gemm_t gemm; + ggml_vec_dot_t vec_dot; + enum ggml_type vec_dot_type; + int64_t nrows; // number of rows to process simultaneously + int64_t ncols; // number of columns to process simultaneously + ggml_gemv_t gemv; + ggml_gemm_t gemm; } ggml_type_traits_t; GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 008718634..40838cf4f 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -20,19 +20,19 @@ // Functions to create the interleaved data layout formats -// interleave 4 block_q4_0s in blocks of interleave_blcksize +// interleave 4 block_q4_0s in blocks of blck_size_interleave // returns an interleaved block_q4_0x4 // in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks -// first, then interleave quants from 4 block_q4_0s in blocks of interleave_blcksize +// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave // 
// - in : an array of block_q4_0 pointers -// - interleave_blcksize : the block_q4_0 quants bytes are interleaved in blocks of -// interleave_blcksize bytes +// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of +// blck_size_interleave bytes // - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes // from bias offset form to pure sign form (this saves subtract // operations durin unpacking) // -static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) { +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { block_q4_0x4 out; for (int i = 0; i < 4; i++) { @@ -40,9 +40,9 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b } for (int i = 0; i < QK4_0 * 2; i++) { - int src_offset = (i / (4 * interleave_blcksize)) * interleave_blcksize; - int src_id = (i % (4 * interleave_blcksize)) / interleave_blcksize; - src_offset += (i % interleave_blcksize); + int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; } @@ -50,11 +50,11 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b return out; } -// interleave 8 block_q4_0s in blocks of interleave_blcksize +// interleave 8 block_q4_0s in blocks of blck_size_interleave // returns an interleaved block_q4_0x8 // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks -// first, then interleave quants from 8 block_q4_0s in blocks of interleave_blcksize -static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) { +// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, 
unsigned int blck_size_interleave, unsigned int xor_mask) { block_q4_0x8 out; for (int i = 0; i < 8; i++) { @@ -62,9 +62,9 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_b } for (int i = 0; i < QK4_0 * 4; i++) { - int src_offset = (i / (8 * interleave_blcksize)) * interleave_blcksize; - int src_id = (i % (8 * interleave_blcksize)) / interleave_blcksize; - src_offset += (i % interleave_blcksize); + int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; } @@ -135,7 +135,7 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) } #else // scalar - const int interleave_blcksize = 4; + const int blck_size_interleave = 4; float srcv[4][QK8_0]; float id[4]; @@ -155,12 +155,12 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) } for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize; - int src_id = (j % (4 * interleave_blcksize)) / interleave_blcksize; - src_offset += (j % interleave_blcksize); + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0);; + y[i].qs[j] = roundf(x0); } } #endif @@ -253,7 +253,7 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) } #else // scalar - const int interleave_blcksize = 8; + const int blck_size_interleave = 8; float srcv[4][QK8_0]; float id[4]; @@ -273,26 +273,30 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) } for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize; - int src_id 
= (j % (4 * interleave_blcksize)) / interleave_blcksize; - src_offset += (j % interleave_blcksize); + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0);; + y[i].qs[j] = roundf(x0); } } #endif } -void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t interleave_blcksize) { +void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { assert(nrow == 4); UNUSED(nrow); - if (interleave_blcksize == 4) quantize_q8_0_4x4(x, vy, n_per_row); - else if (interleave_blcksize == 8) quantize_q8_0_4x8(x, vy, n_per_row); - else assert(false); + if (blck_size_interleave == 4) { + quantize_q8_0_4x4(x, vy, n_per_row); + } else if (blck_size_interleave == 8) { + quantize_q8_0_4x8(x, vy, n_per_row); + } else { + assert(false); + } } -static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int interleave_blcksize) { +static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) { assert(n_per_row % QK4_0 == 0); const int nb = n_per_row / QK4_0; @@ -311,15 +315,15 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds for (int64_t x = 0; x < nb; x++) { for (int i = 0; i < nrows_interleaved; i++ ) { - quantize_row_q4_0_reference(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); + quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); } if (nrows_interleaved == 8) { - *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, interleave_blcksize, 0x88); + *(block_q4_0x8 *) out_ptr = 
make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88); out_ptr = (block_q4_0x8 *) out_ptr + 1; } else if (nrows_interleaved == 4) { - *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, interleave_blcksize, 0x88); + *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88); out_ptr = (block_q4_0x4 *) out_ptr + 1; } } diff --git a/ggml/src/ggml-aarch64.h b/ggml/src/ggml-aarch64.h index 65ead1efe..517babaf1 100644 --- a/ggml/src/ggml-aarch64.h +++ b/ggml/src/ggml-aarch64.h @@ -16,7 +16,7 @@ extern "C" { void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t interleave_blcksize); +void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave); // Quantization utilizing an importance matrix (a.k.a. 
"Activation aWare Quantization") size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); @@ -24,14 +24,14 @@ size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT d size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); // GEMV -void ggml_gemv_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); // GEMM -void ggml_gemm_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x8_q8_0(int n, 
float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #ifdef __cplusplus } diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index cbe377cf5..1839a722e 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -658,7 +658,7 @@ static inline __m128i packNibbles( __m256i bytes ) { #endif //__loongarch_asx // reference implementation for deterministic creation of model files -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { +void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -696,11 +696,11 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict } void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_0_reference(x, y, k); + quantize_row_q4_0_ref(x, y, k); } -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) { +void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; assert(k % qk == 0); @@ -738,10 +738,10 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict } void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_1_reference(x, y, k); + quantize_row_q4_1_ref(x, y, k); } -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) { +void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -786,10 +786,10 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict } void quantize_row_q5_0(const float * restrict x, 
void * restrict y, int64_t k) { - quantize_row_q5_0_reference(x, y, k); + quantize_row_q5_0_ref(x, y, k); } -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) { +void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; assert(k % qk == 0); @@ -834,11 +834,11 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict } void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_1_reference(x, y, k); + quantize_row_q5_1_ref(x, y, k); } // reference implementation for deterministic creation of model files -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) { +void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -1144,12 +1144,12 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) #else GGML_UNUSED(nb); // scalar - quantize_row_q8_0_reference(x, y, k); + quantize_row_q8_0_ref(x, y, k); #endif } // reference implementation for deterministic creation of model files -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) { +void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -1508,7 +1508,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #else GGML_UNUSED(nb); // scalar - quantize_row_q8_1_reference(x, y, k); + quantize_row_q8_1_ref(x, y, k); #endif } @@ -1899,7 +1899,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * //========================- 2-bit (de)-quantization -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) { +void quantize_row_q2_K_ref(const float * restrict x, block_q2_K 
* restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2002,7 +2002,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6 } void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q2_K_reference(x, vy, k); + quantize_row_q2_K_ref(x, vy, k); } static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, @@ -2226,7 +2226,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); if (!quant_weights) { - quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2241,7 +2241,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr //========================= 3-bit (de)-quantization -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) { +void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2368,7 +2368,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6 } void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q3_K_reference(x, vy, k); + quantize_row_q3_K_ref(x, vy, k); } static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { @@ -2458,7 +2458,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = 
ggml_row_size(GGML_TYPE_Q3_K, n_per_row); if (!quant_weights) { - quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2473,7 +2473,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 4-bit (de)-quantization -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) { +void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2572,7 +2572,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6 void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q4_K * restrict y = vy; - quantize_row_q4_K_reference(x, y, k); + quantize_row_q4_K_ref(x, y, k); } static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -2651,7 +2651,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); if (!quant_weights) { - quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2666,7 +2666,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 5-bit (de)-quantization -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) { +void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -2783,7 +2783,7 @@ void dequantize_row_q5_K(const 
block_q5_K * restrict x, float * restrict y, int6 void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q5_K * restrict y = vy; - quantize_row_q5_K_reference(x, y, k); + quantize_row_q5_K_ref(x, y, k); } static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -2882,7 +2882,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); if (!quant_weights) { - quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2897,7 +2897,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 6-bit (de)-quantization -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) { +void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3001,7 +3001,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6 void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q6_K * restrict y = vy; - quantize_row_q6_K_reference(x, y, k); + quantize_row_q6_K_ref(x, y, k); } static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -3091,7 +3091,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = 
ggml_row_size(GGML_TYPE_Q6_K, n_per_row); if (!quant_weights) { - quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -3108,7 +3108,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri static_assert(QK4_0 == 32, "QK4_0 must be 32"); if (!quant_weights) { - quantize_row_q4_0_reference(x, y, n_per_row); + quantize_row_q4_0_ref(x, y, n_per_row); return; } @@ -3134,7 +3134,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); @@ -3151,7 +3151,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri static_assert(QK4_1 == 32, "QK4_1 must be 32"); if (!quant_weights) { - quantize_row_q4_1_reference(x, y, n_per_row); + quantize_row_q4_1_ref(x, y, n_per_row); return; } @@ -3179,7 +3179,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); @@ -3196,7 +3196,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri static_assert(QK5_0 == 32, "QK5_0 must be 32"); if (!quant_weights) { - quantize_row_q5_0_reference(x, y, n_per_row); 
+ quantize_row_q5_0_ref(x, y, n_per_row); return; } @@ -3233,7 +3233,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); @@ -3250,7 +3250,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri static_assert(QK5_1 == 32, "QK5_1 must be 32"); if (!quant_weights) { - quantize_row_q5_1_reference(x, y, n_per_row); + quantize_row_q5_1_ref(x, y, n_per_row); return; } @@ -3286,7 +3286,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); @@ -3302,7 +3302,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); - quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } @@ -3590,7 +3590,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, //===================================== Q8_K 
============================================== -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) { +void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3641,7 +3641,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6 } void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q8_K_reference(x, y, k); + quantize_row_q8_K_ref(x, y, k); } //===================================== Dot ptoducts ================================= @@ -13530,10 +13530,10 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_xxs * restrict y = vy; - quantize_row_iq3_xxs_reference(x, y, k); + quantize_row_iq3_xxs_ref(x, y, k); } -void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { +void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_row_iq3_xxs_impl(256, x, y, k, NULL); } @@ -13746,10 +13746,10 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_s * restrict y = vy; - quantize_row_iq3_s_reference(x, y, k); + quantize_row_iq3_s_ref(x, y, k); } -void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) { +void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq3_s(x, y, 1, k, NULL); } @@ -14487,7 +14487,7 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k } } -void quantize_row_iq4_nl_reference(const float * restrict x, 
block_iq4_nl * restrict y, int64_t k) { +void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { assert(k % QK4_NL == 0); quantize_row_iq4_nl(x, y, k); } @@ -14515,10 +14515,10 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq4_xs * restrict y = vy; - quantize_row_iq4_xs_reference(x, y, k); + quantize_row_iq4_xs_ref(x, y, k); } -void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { +void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq4_xs(x, y, 1, k, NULL); } @@ -14705,7 +14705,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n return nrow * nblock * sizeof(block_iq2_s); } -void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) { +void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq2_s(x, y, 1, k, NULL); } @@ -14713,7 +14713,7 @@ void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restri void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq2_s * restrict y = vy; - quantize_row_iq2_s_reference(x, y, k); + quantize_row_iq2_s_ref(x, y, k); } static bool validate_float(float f, size_t i) { diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 30983b872..88b1f3269 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -12,25 +12,25 @@ extern "C" { #endif // Quantization -void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); -void 
quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); +void 
quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 1bb731e16..9a5414787 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -592,7 +592,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, - .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row, + .from_float_ref = 
(ggml_from_float_t) ggml_fp32_to_fp16_row, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, .vec_dot_type = GGML_TYPE_F16, .nrows = 1, @@ -604,7 +604,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_0, .from_float = quantize_row_q4_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref, .vec_dot = ggml_vec_dot_q4_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -620,7 +620,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_1, .from_float = quantize_row_q4_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, .vec_dot = ggml_vec_dot_q4_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -636,7 +636,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, @@ -648,7 +648,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, @@ -660,7 +660,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_0, .from_float = quantize_row_q5_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref, .vec_dot = ggml_vec_dot_q5_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -672,7 +672,7 @@ static const ggml_type_traits_t 
type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_1, .from_float = quantize_row_q5_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, .vec_dot = ggml_vec_dot_q5_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, @@ -684,7 +684,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q8_0, .from_float = quantize_row_q8_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref, + .from_float_to_mat = quantize_mat_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -692,7 +693,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .nrows = 1, #endif - .from_float_to_mat = quantize_mat_q8_0, }, [GGML_TYPE_Q8_1] = { .type_name = "q8_1", @@ -700,7 +700,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_1), .is_quantized = true, .from_float = quantize_row_q8_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, }, @@ -711,7 +711,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q2_K, .from_float = quantize_row_q2_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref, .vec_dot = ggml_vec_dot_q2_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -723,7 +723,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q3_K, .from_float = quantize_row_q3_K, - 
.from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref, .vec_dot = ggml_vec_dot_q3_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -735,7 +735,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_K, .from_float = quantize_row_q4_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref, .vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -747,7 +747,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_K, .from_float = quantize_row_q5_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref, .vec_dot = ggml_vec_dot_q5_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -759,7 +759,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q6_K, .from_float = quantize_row_q6_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref, .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -771,7 +771,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq2_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -783,7 +783,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_xs, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, 
.vec_dot = ggml_vec_dot_iq2_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -795,7 +795,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs, .from_float = quantize_row_iq3_xxs, - .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref, .vec_dot = ggml_vec_dot_iq3_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -807,7 +807,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq3_s, .from_float = quantize_row_iq3_s, - .from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref, .vec_dot = ggml_vec_dot_iq3_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -819,7 +819,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_s, .from_float = quantize_row_iq2_s, - .from_float_reference = (ggml_from_float_t)quantize_row_iq2_s_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref, .vec_dot = ggml_vec_dot_iq2_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -831,7 +831,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_s, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq1_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -843,7 +843,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_m, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq1_m_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -855,7 +855,7 @@ static const 
ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq4_nl, .from_float = quantize_row_iq4_nl, - .from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, .vec_dot = ggml_vec_dot_iq4_nl_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -867,7 +867,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq4_xs, .from_float = quantize_row_iq4_xs, - .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref, .vec_dot = ggml_vec_dot_iq4_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -886,7 +886,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, - .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row, + .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot_type = GGML_TYPE_BF16, .nrows = 1, @@ -894,48 +894,48 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0_4_4] = { .type_name = "q4_0_4x4", .blck_size = QK4_0, + .blck_size_interleave = 4, .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, .ncols = 4, - .interleave_blcksize = 4, .gemv = ggml_gemv_q4_0_4x4_q8_0, .gemm = ggml_gemm_q4_0_4x4_q8_0, }, [GGML_TYPE_Q4_0_4_8] = { .type_name = "q4_0_4x8", .blck_size = QK4_0, + .blck_size_interleave = 8, .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, 
.vec_dot = NULL, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, .ncols = 4, - .interleave_blcksize = 8, .gemv = ggml_gemv_q4_0_4x8_q8_0, .gemm = ggml_gemm_q4_0_4x8_q8_0, }, [GGML_TYPE_Q4_0_8_8] = { .type_name = "q4_0_8x8", .blck_size = QK4_0, + .blck_size_interleave = 8, .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, .ncols = 8, - .interleave_blcksize = 8, .gemv = ggml_gemv_q4_0_8x8_q8_0, .gemm = ggml_gemm_q4_0_8x8_q8_0, } @@ -3115,7 +3115,7 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); } -GGML_CALL int ggml_blck_size(enum ggml_type type) { +GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { return type_traits[type].blck_size; } @@ -12192,15 +12192,14 @@ static void ggml_compute_forward_mul_mat( const enum ggml_type type = src0->type; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - int64_t const vec_dot_num_rows = type_traits[type].nrows; - int64_t const matmul_num_cols = type_traits[type].ncols; - int64_t const interleave_blcksize = type_traits[type].interleave_blcksize; - ggml_from_float_to_mat_t const from_float_to_mat - = type_traits[vec_dot_type].from_float_to_mat; - ggml_gemv_t const gemv = type_traits[type].gemv; - ggml_gemm_t const gemm = type_traits[type].gemm; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; + int64_t const vec_dot_num_rows = type_traits[type].nrows; + int64_t const matmul_num_cols = type_traits[type].ncols; + int64_t const blck_size_interleave = type_traits[type].blck_size_interleave; + ggml_gemv_t 
const gemv = type_traits[type].gemv; + ggml_gemm_t const gemm = type_traits[type].gemm; GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); @@ -12264,14 +12263,14 @@ UseGgmlGemm1:; for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - 4, ne10, interleave_blcksize); + 4, ne10, blck_size_interleave); } i11_processed = ne11 - ne11 % 4; } for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); } } } @@ -12355,7 +12354,7 @@ UseGgmlGemm2:; int64_t src0_start = (ith * ne01) / nth; int64_t src0_end = ((ith + 1) * ne01) / nth; src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start; - src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; + src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; if (src0_start >= src0_end) return; // If there are more than three rows in src1, use gemm; otherwise, use gemv. 
@@ -12413,11 +12412,11 @@ static void ggml_compute_forward_mul_mat_id( const bool src1_cont = ggml_is_contiguous(src1); - ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - int64_t const matmul_num_cols = type_traits[type].ncols; - ggml_gemv_t const gemv = type_traits[type].gemv; + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + int64_t const matmul_num_cols = type_traits[type].ncols; + ggml_gemv_t const gemv = type_traits[type].gemv; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); @@ -12458,9 +12457,9 @@ static void ggml_compute_forward_mul_mat_id( for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = ith; i11 < ne11; i11 += nth) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); } } } @@ -21063,8 +21062,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p (int64_t) info->ne[3]; if (ne % ggml_blck_size(info->type) != 0) { - fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", - __func__, info->name.data, (int)info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); + fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", + __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); 
fclose(file); gguf_free(ctx); return NULL; diff --git a/tests/test-double-float.cpp b/tests/test-double-float.cpp index 753dae911..6aac4737a 100644 --- a/tests/test-double-float.cpp +++ b/tests/test-double-float.cpp @@ -14,7 +14,7 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wdouble-promotion" -// ggml.c::quantize_row_q4_0_reference +// ggml.c::quantize_row_q4_0_ref inline static uint8_t round_orig(float v0) { return ((int8_t) (round(v0))) + 8; } // ggml.c::ggml_silu_f32 @@ -24,7 +24,7 @@ inline static float silu_orig(float x) { #pragma GCC diagnostic pop -// ggml.c::quantize_row_q4_0_reference +// ggml.c::quantize_row_q4_0_ref inline static uint8_t round_float(float v0) { return (int8_t)roundf(v0) + 8; } // ggml.c::ggml_silu_f32 diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index e690ac6c8..c97458d1d 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -60,7 +60,7 @@ static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test qfns.from_float(test_data, tmp_q.data(), test_size); qfns.to_float(tmp_q.data(), tmp_out.data(), test_size); - qfns.from_float_reference(test_data, tmp_q.data(), test_size); + qfns.from_float_ref(test_data, tmp_q.data(), test_size); qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size); return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size); diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp index 48d9fae3d..24e066053 100644 --- a/tests/test-quantize-perf.cpp +++ b/tests/test-quantize-perf.cpp @@ -285,7 +285,7 @@ int main(int argc, char * argv[]) { for (size_t size : params.test_sizes) { printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024)); auto quantize_fn = [&](void) -> float { - qfns.from_float_reference(test_data1, test_q1, size); + qfns.from_float_ref(test_data1, test_q1, size); return test_q1[0]; }; size_t quantized_size = ggml_row_size(type, size); From 71c1121d11f1437be9421fd0cbaa011b9ef49098 Mon 
Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jul 2024 10:46:14 +0300 Subject: [PATCH 06/13] examples : sprintf -> snprintf (#8434) * examples : sprintf -> snprintf ggml-ci * examples : use sizeof() instead of hardcoded constants --- examples/eval-callback/eval-callback.cpp | 2 +- examples/gguf-hash/gguf-hash.cpp | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 64cd338c2..c8a3016a4 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -99,7 +99,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { char src1_str[128] = {0}; if (src1) { - sprintf(src1_str, "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); + snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); } printf("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, diff --git a/examples/gguf-hash/gguf-hash.cpp b/examples/gguf-hash/gguf-hash.cpp index c34728c3d..e96c75117 100644 --- a/examples/gguf-hash/gguf-hash.cpp +++ b/examples/gguf-hash/gguf-hash.cpp @@ -347,7 +347,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[17]; for (int offset = 0; offset < 8; offset++) { unsigned int shift_bits_by = (8 * (8 - offset - 1)); - sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); } if (hash_params.manifest_is_usable) { @@ -384,7 +384,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[41] = {0}; for (int offset = 0; offset < 20; offset++) { - sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); } if 
(hash_params.manifest_is_usable) { @@ -421,7 +421,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { - sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); } if (hash_params.manifest_is_usable) { @@ -460,7 +460,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[17]; for (int offset = 0; offset < 8; offset++) { unsigned int shift_bits_by = (8 * (8 - offset - 1)); - sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); } if (hash_params.manifest_is_usable) { @@ -490,7 +490,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[41]; for (int offset = 0; offset < 20; offset++) { - sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); } if (hash_params.manifest_is_usable) { @@ -520,7 +520,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { - sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); } if (hash_params.manifest_is_usable) { @@ -552,7 +552,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { generate_uuidv5(result, uuid); char string_buffer[37] = {0}; - sprintf(string_buffer, "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + snprintf(string_buffer, sizeof(string_buffer), 
"%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", uuid[0], uuid[1], uuid[2], uuid[3], uuid[4], uuid[5], uuid[6], uuid[7], uuid[8], uuid[9], uuid[10], uuid[11], From 5aefbce27a66473d4b1263ba3f7bdd3d14245975 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ji=C5=99=C3=AD=20Podiv=C3=ADn?= <66251151+jpodivin@users.noreply.github.com> Date: Fri, 12 Jul 2024 10:06:33 +0200 Subject: [PATCH 07/13] convert : remove fsep token from GPTRefactForCausalLM (#8237) The token used by Refact doesn't serve the same purpose as the from CodeGemma. Signed-off-by: Jiri Podivin --- convert_hf_to_gguf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ebb5ca376..cf930be17 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1203,11 +1203,10 @@ class RefactModel(Model): # TODO: how to determine special FIM tokens automatically? special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, - special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot']) + special_token_types = ['prefix', 'suffix', 'middle', 'eot']) special_vocab._set_special_token("prefix", 1) special_vocab._set_special_token("suffix", 3) special_vocab._set_special_token("middle", 2) - special_vocab._set_special_token("fsep", 4) # is this correct? special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): From 8a4441ea1a2564578134404f31158c318e9c0bf3 Mon Sep 17 00:00:00 2001 From: Armen Kaleshian Date: Fri, 12 Jul 2024 04:08:19 -0400 Subject: [PATCH 08/13] docker : fix filename for convert-hf-to-gguf.py in tools.sh (#8441) Commit b0a4699 changed the name of this script from convert-hf-to-gguf.py to convert_hf_to_gguf.py breaking how convert is called from within a Docker container. 
--- .devops/tools.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devops/tools.sh b/.devops/tools.sh index 335382f69..cf0e8f32d 100755 --- a/.devops/tools.sh +++ b/.devops/tools.sh @@ -8,7 +8,7 @@ arg1="$1" shift if [[ "$arg1" == '--convert' || "$arg1" == '-c' ]]; then - python3 ./convert-hf-to-gguf.py "$@" + python3 ./convert_hf_to_gguf.py "$@" elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then ./llama-quantize "$@" elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then From c3ebcfa148e867a68e78fd5c4f0c23e8f84c788b Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Fri, 12 Jul 2024 03:14:12 -0500 Subject: [PATCH 09/13] server : ensure batches are either all embed or all completion (#8420) * make sure batches are all embed or all non-embed * non-embedding batch for sampled tokens; fix unused params warning --- examples/server/server.cpp | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index efd426289..badeb9121 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2005,6 +2005,11 @@ struct server_context { int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); + // track if this is an embedding or non-embedding batch + // if we've added sampled tokens above, we are in non-embedding mode + // -1: none, 0: non-embedding, 1: embedding + int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; + // next, batch any pending prompts without exceeding n_batch if (params.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { @@ -2175,6 +2180,14 @@ struct server_context { } } + // check that we are in the right batch_type, if not defer the slot + bool slot_type = slot.embedding ? 
1 : 0; + if (batch_type == -1) { + batch_type = slot_type; + } else if (batch_type != slot_type) { + continue; + } + // keep only the common part int p0 = (int) system_tokens.size() + slot.n_past; if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) { @@ -2276,6 +2289,9 @@ struct server_context { {"n_tokens", batch.n_tokens}, }); + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, batch_type == 1); + // process the created batch of tokens for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); @@ -2990,6 +3006,11 @@ int main(int argc, char ** argv) { }; const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + if (ctx_server.params.embedding) { + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = json::parse(req.body); @@ -3085,6 +3106,11 @@ int main(int argc, char ** argv) { }; const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) { + if (ctx_server.params.embedding) { + res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); @@ -3157,6 +3183,11 @@ int main(int argc, char ** argv) { }; const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + if (ctx_server.params.embedding) { + res_error(res, format_error_response("This server does not support infill.
Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = json::parse(req.body); @@ -3243,13 +3274,8 @@ int main(int argc, char ** argv) { return res.set_content(data.dump(), "application/json; charset=utf-8"); }; - const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!params.embedding) { - res.status = 501; - res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8"); - return; - } const json body = json::parse(req.body); bool is_openai = false; From f53226245f421bd01b47cce43a47e791de82c636 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 12 Jul 2024 11:05:21 +0200 Subject: [PATCH 10/13] llama : suppress unary minus operator warning (#8448) This commit updates the _try_copy lambda and moves the unary minus operator to after the cast to int32_t.
The motivation for this is that currently the following warning is generated on Windows: ```console llama.cpp\src\llama.cpp(21147,30): warning C4146: unary minus operator applied to unsigned type, result still unsigned ``` --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index f91ac7779..59b76a6d8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21144,7 +21144,7 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token size--; } if (length < (int32_t)size) { - return (int32_t) -size; + return -(int32_t) size; } memcpy(buf, token, size); return (int32_t) size; From 6af51c0d96e6268769fc05c98d5b1a5e832c0017 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jul 2024 14:48:04 +0300 Subject: [PATCH 11/13] main : print error on empty input (#8456) --- examples/main/main.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4ef55c1e6..a0d817b1a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -289,8 +289,13 @@ int main(int argc, char ** argv) { // Should not run without any tokens if (embd_inp.empty()) { - embd_inp.push_back(llama_token_bos(model)); - LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + if (add_bos) { + embd_inp.push_back(llama_token_bos(model)); + LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + } else { + LOG_TEE("error: input is empty\n"); + return -1; + } } // Tokenize negative prompt From 4e24cffd8cccd653634e24ee461c252bd77b1426 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jul 2024 14:48:15 +0300 Subject: [PATCH 12/13] server : handle content array in chat API (#8449) * server : handle content array in chat API * Update examples/server/utils.hpp Co-authored-by: Xuan Son Nguyen --------- Co-authored-by: Xuan Son Nguyen 
--- examples/server/utils.hpp | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 7ef2a519a..db6b3b74d 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -122,8 +122,26 @@ inline std::string format_chat(const struct llama_model * model, const std::stri for (size_t i = 0; i < messages.size(); ++i) { const auto & curr_msg = messages[i]; - std::string role = json_value(curr_msg, "role", std::string("")); - std::string content = json_value(curr_msg, "content", std::string("")); + + std::string role = json_value(curr_msg, "role", std::string("")); + + std::string content; + if (curr_msg.contains("content")) { + if (curr_msg["content"].is_string()) { + content = curr_msg["content"].get(); + } else if (curr_msg["content"].is_array()) { + for (const auto & part : curr_msg["content"]) { + if (part.contains("text")) { + content += "\n" + part["text"].get(); + } + } + } else { + throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + } else { + throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + chat.push_back({role, content}); } From c917b67f06c42d8ca8391b9bc73f5fe62c83bf70 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 13 Jul 2024 18:32:33 +0300 Subject: [PATCH 13/13] metal : template-ify some of the kernels (#8447) ggml-ci --- ggml/src/ggml-metal.m | 28 +- ggml/src/ggml-metal.metal | 723 ++++++++++---------------------------- 2 files changed, 190 insertions(+), 561 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 79902c9a8..b5939efa6 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -193,16 +193,16 @@ enum ggml_metal_kernel_type { //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 - GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, + GGML_METAL_KERNEL_TYPE_CPY_F32_F16, + GGML_METAL_KERNEL_TYPE_CPY_F16_F16, + GGML_METAL_KERNEL_TYPE_CPY_F16_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, - GGML_METAL_KERNEL_TYPE_CPY_F16_F16, - GGML_METAL_KERNEL_TYPE_CPY_F16_F32, GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_SQR, GGML_METAL_KERNEL_TYPE_SUM_ROWS, @@ -651,14 +651,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); @@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const switch (op->src[0]->type) { case GGML_TYPE_F32: switch (op->type) { - case GGML_TYPE_F16: case GGML_TYPE_F32: + case GGML_TYPE_F16: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const } case GGML_TYPE_F16: switch (op->type) { - case GGML_TYPE_F16: case GGML_TYPE_F32: + case GGML_TYPE_F16: return true; default: return false; @@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_DIAG_MASK_INF: case GGML_OP_GET_ROWS: { - return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1; + return op->ne[3] == 1; } default: return false; @@ -1580,8 +1580,8 @@ static enum ggml_status ggml_metal_graph_compute( // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; default: break; } @@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); switch (dstt) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; case GGML_TYPE_Q8_0: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; @@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_F16: { switch (dstt) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; default: GGML_ASSERT(false && "not implemented"); }; } break; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c3503479b..2a3b0c0a6 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32( kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } -#define N_F32_F32 4 +#define N_MV_T_T 4 -void kernel_mul_mv_f32_f32_impl( +template +void kernel_mul_mv_impl( device const char * src0, device const char * src1, device float * dst, @@ -1239,13 +1240,12 @@ void kernel_mul_mv_f32_f32_impl( uint64_t nb12, int64_t ne0, int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F32_F32; + const int64_t rb = tgpig.y*N_MV_T_T; const int64_t im = tgpig.z; const uint i12 = im%ne12; @@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const float * x = (device const float *) (src0 + offset0); + device const T0 * x = (device const T0 *) (src0 + offset0); if (ne00 < 128) { - for 
(int row = 0; row < N_F32_F32; ++row) { + for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; if (r1 >= ne11) { break; } - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); float sumf = 0; for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; + sumf += (T0) x[i] * (T1) y[i]; } float all_sum = simd_sum(sumf); @@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl( } } } else { - device const float4 * x4 = (device const float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { + device const T04 * x4 = (device const T04 *) x; + for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; if (r1 >= ne11) { break; } - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; + device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + device const T14 * y4 = (device const T14 *) y; float sumf = 0; for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } } } -[[host_name("kernel_mul_mv_f32_f32")]] -kernel void kernel_mul_mv_f32_f32( +template +kernel void kernel_mul_mv( device const char * src0, device const char * src1, device float * dst, @@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32( constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); + kernel_mul_mv_impl( + src0, + src1, 
+ dst, + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); } -#define N_F16_F16 4 +typedef decltype(kernel_mul_mv) mul_mv_t; -kernel void kernel_mul_mv_f16_f16( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F16; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const 
half *) (src1 + r1*nb11 + im*nb12); - device const half4 * y4 = (device const half4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -void kernel_mul_mv_f16_f32_1row_impl( +template +kernel void kernel_mul_mv_1row( device const char * src0, device const char * src1, device float * dst, @@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const half * x = (device const half *) (src0 + offset0); + device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); float sumf = 0; @@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl( dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } else { - device const half4 * x4 = (device const half4 *) x; + device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } } -[[host_name("kernel_mul_mv_f16_f32_1row")]] -kernel void kernel_mul_mv_f16_f32_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, 
- constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} +typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; -#define N_F16_F32 4 - -void kernel_mul_mv_f16_f32_impl( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F32; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 
32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -[[host_name("kernel_mul_mv_f16_f32")]] -kernel void kernel_mul_mv_f16_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} +template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; // Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( +template +kernel void kernel_mul_mv_l4( device const char * src0, device const char * src1, device float * dst, @@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const half4 * x4 = (device const half4 *) (src0 + offset0); + device const T4 * x4 = (device const T4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); float sumf = 0; for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); @@ 
-1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4( } } +typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; + +template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; + static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -2765,9 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16( template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; -kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * dst, +template +kernel void kernel_cpy( + device const void * src0, + device void * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -2798,138 +2627,20 @@ kernel void kernel_cpy_f16_f16( const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; + device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = (T1) src[0]; } } -kernel void kernel_cpy_f16_f32( - device const half * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant 
int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; +typedef decltype(kernel_cpy) kernel_cpy_t; - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - device const float * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - 
i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t 
kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; kernel void kernel_cpy_f32_q8_0( device const float * src0, @@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 } template -kernel void kernel_get_rows( +kernel void kernel_get_rows_q( device const void * src0, - device const char * src1, + device const void * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, @@ -5745,27 +5456,24 @@ kernel void kernel_get_rows( uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { - //const int64_t i = tgpig; - //const int64_t r = ((device int32_t *) src1)[i]; - const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; const int64_t i02 = i11; for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { float4x4 temp; - dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; } } -kernel void kernel_get_rows_f32( +template +kernel void kernel_get_rows_f( device const void * src0, - device const char * src1, + device const void * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, @@ -5781,47 +5489,19 @@ kernel void kernel_get_rows_f32( const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; const int64_t i02 = i11; for (int ind = tiitg; ind 
< ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_f16( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; } } kernel void kernel_get_rows_i32( device const void * src0, - device const char * src1, + device const void * src1, device int32_t * dst, constant int64_t & ne00, constant uint64_t & nb01, @@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32( const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; const int64_t i02 = i11; for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device int32_t *) ((const 
device char *) src0 + i02*nb02 + r*nb01))[ind]; } } @@ -5860,28 +5540,28 @@ kernel void kernel_get_rows_i32( #define SG_MAT_ROW 8 // each block_q contains 16*nl weights -template -void kernel_mul_mm_impl(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +template +kernel void kernel_mul_mm(device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup T * sa = (threadgroup T *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); const uint r0 = tgpig.y; @@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0, short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? 
((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; + simdgroup_T8x8 ma[4]; simdgroup_float8x8 mb[2]; simdgroup_float8x8 c_res[8]; for (int i = 0; i < 8; i++){ @@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0, for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - half4x4 temp_a; + T4x4 temp_a; dequantize_func(x, il, temp_a); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0, threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); #pragma unroll(4) @@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl( } } -template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mm_impl( - src0, - src1, - dst, - ne00, - ne02, - nb01, - nb02, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - shared_memory, - tgpig, - tiitg, - sgitg); -} - template kernel void kernel_mul_mm_id( device const uchar * src0s, @@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id( // get rows // -typedef void (get_rows_t)( - device const void * src0, - device const char * src1, - 
device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3, uint, uint3); +typedef decltype(kernel_get_rows_f) get_rows_f_t; -//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; -template 
[[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; + +typedef decltype(kernel_get_rows_q) get_rows_q_t; + +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; // // matrix-matrix 
multiplication // -typedef decltype(kernel_mul_mm) mat_mm_t; +typedef decltype(kernel_mul_mm) mat_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template 
[[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication @@ -6436,7 +6065,7 @@ void mmv_fn( impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } -typedef decltype(mmv_fn) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( @@ -6514,20 +6143,20 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; 
+typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel 
kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;