diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e12c922bd..2c51b6c42 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3791,7 +3791,7 @@ class Plamo2Model(TextModel): self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32)) self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06)) - self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0)) + self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000)) # Mamba parameters self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64)) @@ -3802,7 +3802,7 @@ class Plamo2Model(TextModel): self.gguf_writer.add_ssm_group_count(0) # MLP feed forward parameters (for attention layers) - self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384)) + self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 13312)) self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index d416f23ab..b54f0b25a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -56,7 +56,7 @@ #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a -#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 @@ -72,8 +72,9 @@ #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) -#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA) -#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 @@ -230,6 +231,10 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) +#define AMD_MFMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) + #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING @@ -292,6 +297,11 @@ static bool fp32_mma_hardware_available(const int cc) { return GGML_CUDA_CC_IS_CDNA(cc); } +// AMD CDNA3 
matrix cores.. Will add support for other CDNA generations later. +static bool amd_mfma_available(const int cc) { + return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc); +} + // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. static bool new_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 565853bfe..83cf872f6 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1330,14 +1330,16 @@ static __global__ void flash_attn_ext_f16( ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); - GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); - GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 2e2ed5cd5..11778bb96 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -37,16 +37,16 @@ static __global__ void flash_attn_tile_ext_f32( #endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + 
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; return; } @@ -282,16 +282,16 @@ static __global__ void flash_attn_tile_ext_f32( } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 8b6ad893e..1054bcf9a 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -329,16 +329,16 @@ static __global__ void flash_attn_vec_ext_f16( } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 2af63355a..d6817d804 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -12,7 +12,8 @@ // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. // All matrix tiles have ne physical 32 bit elements per warp. // -// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. 
+// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. +// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior. #include "common.cuh" @@ -66,7 +67,44 @@ namespace ggml_cuda_mma { struct tile { static constexpr int I = I_; static constexpr int J = J_; - static constexpr int ne = I * J / WARP_SIZE; + +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + static constexpr int ne = I * J / 64; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 4) { + return threadIdx.x % 32; + } else if constexpr (I == 16 && J == 16) { + return 4 * (threadIdx.x / 16) + l; + } else if constexpr (I == 32 && J == 32) { + return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4); + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + return (2 * ((threadIdx.x / 16) % 2) + l); + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x / 16) + l; + } else if constexpr (I == 32 && J == 4) { + return 2 * (threadIdx.x / 32) + l; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 32) { + return threadIdx.x % 32; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } +#else + static constexpr int ne = I * J / 32; T x[ne] = {0}; static __device__ __forceinline__ int get_i(const int l) { @@ -94,6 +132,7 @@ namespace ggml_cuda_mma { static_assert(I == -1 && J == -1, "template specialization not implemented"); } } +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) }; template @@ -148,10 +187,23 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { +#if defined(AMD_MFMA_AVAILABLE) + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } + } else { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + } +#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } +#endif // defined(AMD_MFMA_AVAILABLE) } template @@ -186,7 +238,7 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void load_ldmatrix( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { -#ifdef NEW_MMA_AVAILABLE +#if defined(NEW_MMA_AVAILABLE) int * xi = (int * ) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" @@ -393,4 +445,60 @@ namespace ggml_cuda_mma { NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } + + static __device__ __forceinline__ void mma( + tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) { +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = 
__attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], + ((int64_t *) B.x)[0], + acc[0], + 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], + B.x[0], + acc[0], + 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], + B.x[1], + acc[0], + 0, 0, 0); +#endif // defined(CDNA3) +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) { +#if defined(AMD_MFMA_AVAILABLE) + using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; + int32x16_t * acc = (int32x16_t *) D.x; +#if defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], + ((int64_t *) B.x)[0], + acc[0], + 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], + B.x[0], + acc[0], + 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], + B.x[1], + acc[0], + 0, 0, 0); +#endif // defined(CDNA3) +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } } diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index cca4b14fb..0bb7d6e61 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -109,7 +109,8 @@ void ggml_cuda_mul_mat_q( const int64_t s03 = src0->nb[3] / ts_src0; const int64_t s3 = dst->nb[3] / ts_dst; - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; + const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc))); if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + @@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q( // The stream-k decomposition is only faster for recent NVIDIA GPUs. // Also its fixup needs to allocate a temporary buffer in the memory pool. // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && - ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11; + const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc))) + && src1_ncols == ne11; const mmq_args args = { src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst, @@ -306,7 +308,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - if (new_mma_available(cc)) { + if (new_mma_available(cc) || amd_mfma_available(cc)) { return true; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 4230dd814..8aa0ceff2 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -91,7 +91,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return new_mma_available(cc) ? 128 : + return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 
#ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -101,12 +101,12 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) return 128; -#else // NEW_MMA_AVAILABLE +#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) - return 128; + return 64; #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA @@ -116,12 +116,11 @@ static constexpr __device__ int get_mmq_x_max_device() { return MMQ_DP4A_MAX_BATCH_SIZE; #endif // GGML_CUDA_FORCE_MMQ #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - return 64; #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } static int get_mmq_y_host(const int cc) { @@ -145,16 +144,25 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } -#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} -#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} -#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} -#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0} -#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} -#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} -#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes. +// The K dimension of the tiles has either, +// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K), +// 32 bit elements for the quantized data (does not include scales). +// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K. +// The final tile size in K direction is padded to avoid shared memory bank conflicts, +// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma. 
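+// A worked example of the padding rule above (a sketch assuming the usual ggml constants, i.e. QI8_0 == 8):
+//   dp4a: the Q4_0 X tile stores MMQ_TILE_NE_K + 1 == 33 32-bit elements per row, and 33 % 2 == 1.
+//   mma:  MMQ_MMA_TILE_X_K_Q8_0 == 2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4 == 64 + 8 + 4 == 76, and 76 % 8 == 4,
+//         consistent with the static_assert checks on the MMQ_MMA_TILE_X_K_* values below.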
+#define MMQ_TILE_NE_K 32 + +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { switch (type) { @@ -180,11 +188,11 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } } -#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4) -#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); @@ -216,42 +224,80 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { } } -#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) +// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) +#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8; + if (amd_mfma_available(cc)) { + return mmq_x >= 128 ? 32 : 16; + } else if (new_mma_available(cc) && mmq_x >= 48) { + return 16; + } else { + return 8; + } } -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 128 ? 32 : 16; +} +#elif defined(NEW_MMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 48 ? 
16 : 8; } #else -static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) { +static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) { return 8; } -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE + +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +static int mmq_get_nwarps_host(const int cc) { + return amd_mfma_available(cc) ? 8 : 4; +} +#else +static int mmq_get_nwarps_host(const int /*cc*/) { + return 8; +} +#endif // (GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + +static constexpr __device__ int mmq_get_nwarps_device() { +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +#if defined(AMD_MFMA_AVAILABLE) + return 8; +#else + return 4; +#endif // AMD_MFMA_AVAILABLE +#else + return 8; +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +} // ------------------------------------------------------------ -template static __device__ __forceinline__ void load_tiles_q4_0( +template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + 2*WARP_SIZE); + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI4_0; - const int kqsx = threadIdx.x % QI4_0; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_0; + const int kqsx = txi % QI4_0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -260,20 +306,21 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { - int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -281,17 +328,19 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); const int * x_qs = (const int *) x; @@ -300,7 +349,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -308,7 +357,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); @@ -321,32 +370,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u, - x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template 
static __device__ __forceinline__ void load_tiles_q4_1( +template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI4_1; - const int kqsx = threadIdx.x % QI4_1; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_1; + const int kqsx = txi % QI4_1; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -355,20 +409,21 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { - int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -376,17 +431,19 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // NEW_MMA_AVAILABLE + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = 
mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); const int * x_qs = (const int *) x; @@ -395,7 +452,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -403,7 +460,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); @@ -416,32 +473,37 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u, - x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q5_0( +template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI5_0; - const int kqsx = threadIdx.x % QI5_0; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_0; + const int kqsx = txi % QI5_0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -450,7 +512,7 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; const int ql = get_int_b2(bxi->qs, kqsx); - const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0)); + const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 @@ -466,21 +528,22 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { - int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -488,32 +551,37 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_q5_1( +template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI5_1; - const int kqsx = threadIdx.x % QI5_1; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_1; + const int kqsx = txi % QI5_1; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -522,7 +590,7 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; const int ql = get_int_b4(bxi->qs, kqsx); - const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1)); + const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 @@ -536,21 +604,22 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { - int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -558,32 +627,38 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // NEW_MMA_AVAILABLE + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_q8_0( +template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_tile + 2*WARP_SIZE); + float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI8_0; - const int kqsx = threadIdx.x % QI8_0; + // MMQ_ITER_K / (4 * QR8_0) 
== 64 required. but NV has only 32 threads per warp + constexpr int threads_per_row = 32; + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI8_0; + const int kqsx = txi % QI8_0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -591,21 +666,22 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#ifdef NEW_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else - x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); - x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0; + constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) { - int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -613,17 +689,19 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); const int * x_qs = (const int *) x; @@ -632,7 +710,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -640,21 +718,76 @@ static __device__ __forceinline__ void 
vec_dot_q8_0_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE], - x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K], + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + float dB; + const int j = j0 + tile_C::get_j(0); + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB; + } + } + } + } +#else typedef tile<16, 8, int> tile_A; typedef tile< 8, 8, int> tile_B; typedef tile<16, 8, int> tile_C; @@ -663,23 +796,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + 2*WARP_SIZE; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; const half2 * y_ds = (const half2 *) y; - tile_A A[ntx][WARP_SIZE/QI8_0]; - float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0]; + tile_A A[ntx][MMQ_TILE_NE_K/QI8_0]; + float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); @@ -690,7 +823,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; @@ -701,7 +834,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { tile_B B; float dB[tile_C::ne/2]; @@ -730,11 +863,14 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } } } +#endif // defined(AMD_MFMA_AVAILABLE) } -template +template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); const int * x_qs = (const int *) x; @@ -743,7 +879,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -751,45 +887,95 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; - typedef tile<16, 8, int> tile_A; - typedef tile< 8, 8, int> tile_B; - 
typedef tile<16, 8, int> tile_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; const int * y_qs = (const int *) y + 4; const half2 * y_dm = (const half2 *) y; - tile_A A[ntx][WARP_SIZE/QI8_1]; - float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1]; + tile_A A[ntx][MMQ_TILE_NE_K/QI8_1]; + float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); @@ -800,7 +986,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { const int k0 = k00 + k01; dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); @@ -811,7 +997,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { tile_B B; float2 dsB[tile_C::ne/2]; @@ -837,11 +1023,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( } } } +#endif // defined(AMD_MFMA_AVAILABLE) } -template +// Used for Q3_K, IQ2_S, and IQ2_XS +template static __device__ __forceinline__ void 
vec_dot_q8_0_16_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; const int * x_qs = (const int *) x; @@ -850,7 +1040,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; #pragma unroll @@ -858,23 +1048,73 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl( - &x_qs[i*(2*WARP_SIZE + 1) + k0], + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template +// Used for Q3_K, IQ2_S, and IQ2_XS: +template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(NEW_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -885,10 +1125,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -900,7 +1140,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { const int k0 = k00 + k01; load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); @@ -911,7 +1151,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { const int k0 = k00 + k01; dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; @@ -922,7 +1162,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { tile_B B[2]; float dB[tile_C::ne/2]; @@ -953,26 +1193,29 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE } -template static __device__ __forceinline__ void load_tiles_q2_K( +template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % QI2_K; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); + constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) { - int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -988,11 +1231,11 @@ template static __device__ __forceinlin const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else - x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int sc_m = bxi->scales[kqsx]; @@ -1003,17 +1246,19 @@ template static __device__ __forceinlin const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m 
& 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else - x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik; -#endif // NEW_MMA_AVAILABLE + x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); const int * x_qs = (const int *) x; @@ -1030,7 +1275,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1038,13 +1283,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; constexpr int ns = 2; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } } @@ -1053,7 +1298,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. // As a workaround 2 separate loops are used instead. #pragma unroll - for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1061,23 +1306,89 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; constexpr int ns = 1; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? 
y_df[j0/nwarps].x : y_df[j0/nwarps].y, &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B[0]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(NEW_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -1088,10 +1399,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
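+    // Q2_K stores (d*scale, dmin*min) per 16-value group in x_dm; the dot product must subtract
+    // min * (sum of the q8 values). For most of the tile that sum comes from the precomputed
+    // partial sums in y_ds, but for the last quarter (k01 >= 3/4 of the tile width) it is
+    // recomputed on the fly with an extra mma against an all-ones A tile (A1.x[*] = 0x01010101),
+    // which reduces each column of B to the sum of its int8 values.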
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -1104,7 +1415,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { const int k0 = k00 + k01; load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); @@ -1118,7 +1429,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) { const int k0 = k00 + k01; const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); @@ -1141,7 +1452,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { tile_B B[2]; // Here load_generic is faster than load_ldmatrix. @@ -1149,7 +1460,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); tile_C Cm[2]; - if (k01 >= WARP_SIZE * 3/4) { + if (k01 >= MMQ_TILE_NE_K * 3/4) { tile_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; @@ -1167,16 +1478,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; - if (k01 >= WARP_SIZE * 3/4) { + if (k01 >= MMQ_TILE_NE_K * 3/4) { tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? 
dB[l%2].x : dB[l%2].y); } } } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) { float2 sB[tile_C::ne/2]; #pragma unroll @@ -1199,27 +1510,31 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE } -template static __device__ __forceinline__ void load_tiles_q3_K( +template static __device__ __forceinline__ void load_tiles_q3_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % QI3_K; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) { - int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -1239,17 +1554,18 @@ template static __device__ __forceinlin const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else - x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } + constexpr int rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { - int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4; if (need_check) { i = min(i, i_max); @@ -1257,7 +1573,7 @@ template static __device__ __forceinlin const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - const int ksc = threadIdx.x % (WARP_SIZE/8); + const int ksc = threadIdx.x % 4; const int ksc_low = ksc % (QI3_K/8); const int shift_low = 4 * (ksc / (QI3_K/8)); @@ -1269,23 +1585,23 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) const int8_t * sc8 = (const int8_t *) ≻ const float d = bxi->d; #pragma unroll for (int l = 0; l < int(sizeof(int)); ++l) { - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l]; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l]; } #else - x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; -#endif // NEW_MMA_AVAILABLE + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; +#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -#ifndef NEW_MMA_AVAILABLE +#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { - int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1295,12 +1611,14 @@ template static __device__ __forceinlin x_df[i] = bxi->d; } -#endif // NEW_MMA_AVAILABLE +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) } -template +template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); const int * x_qs = (const int *) x; @@ -1310,7 +1628,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1318,13 +1636,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4; + const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } } @@ -1341,72 +1659,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits } -template static __device__ __forceinline__ void load_tiles_q4_K( +template static __device__ __forceinline__ void load_tiles_q4_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - const int qs0 = get_int_b4(bxi->qs, threadIdx.x); + const int qs0 = get_int_b4(bxi->qs, txi); -#ifdef NEW_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -#ifdef NEW_MMA_AVAILABLE - +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { - int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } - if (need_check) { - i = min(i, i_max); - } + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; - const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/16); + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); - const int sc32 = unpack_scales_q45_K(scales, ksc + 0); - const int m32 = unpack_scales_q45_K(scales, ksc + 2); + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; - const uint8_t * sc8 = (const uint8_t *) &sc32; - const uint8_t * m8 = (const uint8_t *) &m32; + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); - const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); - -#pragma unroll - for (int l = 0; l < int(sizeof(int)); ++l) { - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + #pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } } } - #else - #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) { - int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1416,30 +1747,32 @@ template static __device__ __forceinlin x_dm[i] = bxi->dm; } - + constexpr int rows_per_warp = warp_size / 4; #pragma 
unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8); + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/8); + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); const int scales8 = unpack_scales_q45_K(scales, ksc); - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -template +template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); const int * x_qs = (const int *) x; @@ -1449,7 +1782,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1457,97 +1790,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16); + const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( - &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q5_K( +template static __device__ __forceinline__ void load_tiles_q5_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); + half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - const int ky = QR5_K*threadIdx.x; + const int ky = QR5_K*txi; - const int ql = get_int_b4(bxi->qs, threadIdx.x); + const int ql = get_int_b4(bxi->qs, txi); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010; + const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010; - const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4; + const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else - x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -#ifdef NEW_MMA_AVAILABLE - +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { - int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - - const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/16); - - const int sc32 = unpack_scales_q45_K(scales, ksc + 0); - const int m32 = unpack_scales_q45_K(scales, ksc + 2); - - const uint8_t * sc8 = (const uint8_t *) &sc32; - const uint8_t * m8 = (const uint8_t *) &m32; - - const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); - -#pragma unroll - for (int l = 0; l < int(sizeof(int)); ++l) { - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); - } - } - + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { #else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const 
uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) { - int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y; + for (int l = 0; l < int(sizeof(int)); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1558,9 +1904,10 @@ template static __device__ __forceinlin x_dm[i] = bxi->dm; } + constexpr int rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { - int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1570,17 +1917,19 @@ template static __device__ __forceinlin const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/8); + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); const int scales8 = unpack_scales_q45_K(scales, ksc); - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -template +template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); const int * x_qs = (const int *) x; @@ -1590,7 +1939,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1598,36 +1947,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( - &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q6_K( +template static __device__ __forceinline__ void load_tiles_q6_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float 
*) (x_qs + WARP_SIZE*2); - int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); + int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -1635,67 +1990,67 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; - const int ql = get_int_b2(bxi->ql, threadIdx.x); + const int ql = get_int_b2(bxi->ql, txi); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4)); - const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030; - const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030; + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4)); + const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030; - const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0; - const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2; + const int kq0 = 2*txi - txi % (QI6_K/2) + 0; + const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else - x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 - const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 - #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { - int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else - x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) 
|| defined(NEW_MMA_AVAILABLE) } + constexpr int rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; -#ifdef NEW_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else - x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); -#endif // NEW_MMA_AVAILABLE + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); const int * x_qs = (const int *) x; @@ -1705,7 +2060,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1713,23 +2068,74 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]); + const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( - &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, - x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
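+    // Rough sketch of the shared x tile layout for Q6_K in this path: per row, 2*MMQ_TILE_NE_K ints
+    // of 6-bit quants recentred to signed int8, then MMQ_TILE_NE_K/QI6_K floats holding the
+    // super-block scale d, then the packed 8-bit group scales. In the loop below, sc[k01/4] is the
+    // int8 scale of the current 16-value group and the product is scaled by d and the q8_1 block
+    // scale dB.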
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(NEW_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile< 8, 4, int> tile_B; @@ -1739,11 +2145,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; - const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -1756,7 +2162,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); @@ -1764,7 +2170,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) { const int k0 = k00 + k01; #pragma unroll @@ -1794,7 +2200,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( float tmp[ntx][tile_C::ne] = {{0.0f}}; #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { tile_B B[2]; float dB[tile_C::ne/2]; @@ -1833,27 +2239,32 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE } -template static __device__ __forceinline__ void load_tiles_iq4_nl( +template static __device__ __forceinline__ void load_tiles_iq4_nl( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = 
ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI4_NL; - const int kqsx = threadIdx.x % QI4_NL; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_NL; + const int kqsx = txi % QI4_NL; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -1863,22 +2274,24 @@ template static __device__ __forceinlin const int aux_q4 = get_int_b2(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); - const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef NEW_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; + const int k0 = kbx * (2 * QI4_NL) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; - x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) { - int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -1886,31 +2299,35 @@ template static __device__ __forceinlin const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else - x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d); -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq2_xxs( +template static __device__ __forceinline__ void load_tiles_iq2_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if 
defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI2_XXS/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -1933,42 +2350,46 @@ template static __device__ __forceinlin const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else - x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq2_xs( +template static __device__ __forceinline__ void load_tiles_iq2_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI2_XS/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + 
threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -1987,44 +2408,48 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // NEW_MMA_AVAILABLE + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq2_s( +template static __device__ __forceinline__ void load_tiles_iq2_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI2_S/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2050,44 +2475,48 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = 
grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // NEW_MMA_AVAILABLE + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq3_xxs( +template static __device__ __forceinline__ void load_tiles_iq3_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI3_XXS/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2108,42 +2537,46 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else - x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // 
NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq3_s( +template static __device__ __forceinline__ void load_tiles_iq3_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI3_S/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2171,42 +2604,46 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else - x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq1_s( +template static __device__ __forceinline__ void load_tiles_iq1_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2); + half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || 
defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % QI1_S; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) { - int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2226,66 +2663,71 @@ template static __device__ __forceinlin const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#ifdef NEW_MMA_AVAILABLE - x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else - x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // NEW_MMA_AVAILABLE + x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq4_xs( +template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = 0; // threadIdx.x / QI4_XS - const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx; + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); - const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef NEW_MMA_AVAILABLE + const int k0 = 8 * (kqsx / 4) + kqsx % 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; - x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } + constexpr int rows_per_warp = warp_size / 8; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4); if (need_check) { i = min(i, i_max); @@ -2298,18 +2740,21 @@ template static __device__ __forceinlin const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else - x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void mmq_write_back_dp4a( const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, const int stride, const int i_max, const int j_max) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -2319,32 +2764,40 @@ static __device__ __forceinline__ void mmq_write_back_dp4a( } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { continue; } - dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } } -template +template static __device__ __forceinline__ void mmq_write_back_mma( const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, const int stride, const int i_max, const int j_max) { - typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) + constexpr int tileC_IJ = mmq_get_granularity_device(0); + typedef tile tile_C; + constexpr int rows_per_warp = granularity; +#else + typedef tile<16, 8, int> tile_C; constexpr int rows_per_warp = 2 * granularity; +#endif constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
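    // Descriptive note (added comment, not from the original patch): i0 below is the first output row
    // written back by this warp's group. Warps are taken in groups of ntx, and each group covers
    // ntx*tile_C::I (= rows_per_warp) rows of the mmq_y tile; the warps within a group handle
    // different column sub-tiles via the j0 loop that follows.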
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#ifdef NEW_MMA_AVAILABLE +#if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { @@ -2372,179 +2825,181 @@ static __device__ __forceinline__ void mmq_write_back_mma( // ------------------------------------------------------------------------------------------------------------------------------------- -template +template struct mmq_type_traits; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; - static constexpr vec_dot_mmq_t 
vec_dot_mma = vec_dot_q2_K_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = 
vec_dot_q8_0_16_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template +template static __device__ __forceinline__ void mul_mat_q_process_tile( const char * __restrict__ x, const int offset_x, const int * __restrict__ y, const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, const 
int stride_row_x, const int ncols_y, const int stride_col_dst, const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int nwarps = mmq_get_nwarps_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); - constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; extern __shared__ int data_mul_mat_q[]; int * tile_y = data_mul_mat_q + mmq_x; - int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); + int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); -#ifdef NEW_MMA_AVAILABLE - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; - constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; - constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // NEW_MMA_AVAILABLE + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; - float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); @@ -2552,8 +3007,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( { const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { - int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; } @@ -2568,8 +3023,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( { const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { - int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; } @@ -2577,7 +3032,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( __syncthreads(); - vec_dot(tile_x, tile_y, sum, WARP_SIZE); + vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K); __syncthreads(); } @@ -2592,16 +3047,16 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 -template +template #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) - __launch_bounds__(WARP_SIZE*nwarps, 2) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) #else #if 
__CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - __launch_bounds__(WARP_SIZE*nwarps, 1) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1) #else - __launch_bounds__(WARP_SIZE*nwarps, 2) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) static __global__ void mul_mat_q( @@ -2617,6 +3072,9 @@ static __global__ void mul_mat_q( return; } + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); @@ -2628,10 +3086,10 @@ static __global__ void mul_mat_q( // For MoE the correct indices are loaded from ids_dst. extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2640,7 +3098,7 @@ static __global__ void mul_mat_q( __syncthreads(); // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: -#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { const int wt = blockIdx.z / nchannels_y; const int zt = blockIdx.z - wt*nchannels_y; @@ -2668,10 +3126,10 @@ static __global__ void mul_mat_q( // __syncthreads(); // There is no previous tile that could cause a race condition. 
#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2689,12 +3147,12 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; - mul_mat_q_process_tile + mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } -#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA const int64_t blocks_per_ne00 = ncols_x / qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -2746,10 +3204,10 @@ static __global__ void mul_mat_q( __syncthreads(); #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2767,7 +3225,7 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. - mul_mat_q_process_tile + mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); @@ -2813,10 +3271,10 @@ static __global__ void mul_mat_q( // The memory layout for the fixup buffer is always contiguous, therefore reset ids: __syncthreads(); #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2834,13 +3292,13 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. 
- mul_mat_q_process_tile + mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } -template +template static __global__ void mul_mat_q_stream_k_fixup( const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, @@ -2850,7 +3308,10 @@ static __global__ void mul_mat_q_stream_k_fixup( constexpr int blocks_per_iter = MMQ_ITER_K / qk; const int64_t blocks_per_ne00 = ncols_x / qk; - float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; const int nty = (nrows_x + mmq_y - 1) / mmq_y; @@ -2894,10 +3355,10 @@ static __global__ void mul_mat_q_stream_k_fixup( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; + sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } } @@ -2938,14 +3399,14 @@ static __global__ void mul_mat_q_stream_k_fixup( } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { continue; } - dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } return; @@ -2956,7 +3417,7 @@ static __global__ void mul_mat_q_stream_k_fixup( const int col_high = expert_bounds[zt + 1]; const int col_diff = col_high - col_low; - for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) { + for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { ids_dst_shared[j] = ids_dst[col_low + j]; } __syncthreads(); @@ -2976,14 +3437,14 @@ static __global__ void mul_mat_q_stream_k_fixup( } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { continue; } - dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } } @@ -2997,13 +3458,13 @@ struct mmq_args { }; template -static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) { +static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) { const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? 
mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); - return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); + return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } template @@ -3011,14 +3472,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc); const int mmq_y = get_mmq_y_host(cc); - const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); + const dim3 block_dims(warp_size, nwarps, 1); - const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc); + const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; @@ -3033,14 +3496,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (!args.use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3060,8 +3523,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3071,13 +3533,12 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a return; } - mul_mat_q_stream_k_fixup<<>> + mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3087,7 +3548,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a 
return; } - mul_mat_q_stream_k_fixup<<>> + mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } @@ -3095,9 +3556,11 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc); const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); @@ -3108,7 +3571,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { const int granularity = mmq_get_granularity_host(mmq_x, cc); - if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc) > smpbo) { + if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) { continue; } diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index b2acdf855..079834364 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -44,6 +44,9 @@ static __global__ void k_set_rows_quant( block_type * dst_block = dst_row_ptr + i00 / qk; quantize_func(src_block, dst_block); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); } // Template dispatch function for quantized set_rows diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 184d445f5..56e59a058 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -160,7 +160,19 @@ #endif #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) -#define CDNA +#define CDNA // For the entire family +#endif + +#if defined(__gfx942__) +#define CDNA3 +#endif + +#if defined(__gfx90a__) +#define CDNA2 +#endif + +#if defined(__gfx908__) +#define CDNA1 #endif #if defined(__GFX12__) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index b7b3fc49a..8424464d8 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -528,6 +528,7 @@ typedef struct { int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; + int64_t s_off; uint64_t nb01; uint64_t nb02; uint64_t nb03; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index a91b5bbbf..47961c4aa 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -3141,6 +3141,7 @@ static int ggml_metal_encode_node( /*.n_group =*/ n_group, /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, + /*.s_off =*/ ggml_nelements(src1) * sizeof(float), /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, @@ -3169,12 +3170,22 @@ static int ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; [encoder setBytes:&args length:sizeof(args) atIndex:8]; + // One shared memory bucket for each simd group in the threadgroup + // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes + // 
https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + if (d_state >= 32) { + GGML_ASSERT((int64_t)(d_state / 32) <= 32); + const int64_t shmem_size = 32; + GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); + [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; + } + if (ne30 == 1) { // Mamba-2 - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } else { GGML_ASSERT(d_inner == 1); - [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } } break; case GGML_OP_RWKV_WKV6: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f62b9ad54..99a453090 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + const int64_t i0 = tpitg.x; const int64_t i1 = 0; const int64_t ir = tgpig.x; // current head const int64_t i3 = tgpig.y; // current seq @@ -1841,41 +1847,88 @@ kernel void kernel_ssm_scan_f32( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + const int64_t i = i0 + i1*nc; + float s0 = s0_buff[i]; + float s = s_buff[i]; + + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); + device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + 
ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} + device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; - float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; + const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + s = state; + + // Parallel sum: This relies on the fact that this kernel will be + // dispatched with each threadgroup having (d_state, 1, 1) threads which + // are subdivided into SIMD groups of size `sgptg`. The goal is to + // compute y = sum({state * C[i] for i in range(d_state)}). + // To parallelize this effectively, we first use simd_sum over each SIMD + // group to compute the sum of each SIMD group, then place the result in + // the SIMD group's indexed bucket in the shared memory. We then sum + // over the individual group sums to compute the final sum. + + // Computed for each thread + float sumf = state * C[i0]; + + // Sum the threads in the simd group => simd sum + sumf = simd_sum(sumf); + + if (sgptg > 1) { + + // Once per simd group, place the group sum into the shared buffer + if (tiisg == 0) { + shared[sgitg] = sumf; + } + + // Wait for all threads in the threadgroup to reach this point. This + // ensures that all elements of the shared buffer are populated with the + // sum of the individual simd groups. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // For simd group 0 at indices < num simd groups, extract the shared + // simd sum + sumf = 0.0f; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + y[0] = sumf; + } + } + } else if (tiisg == 0) { + y[0] = sumf; } - y[0] = sumf; - // recurse s0 = s; } + + // Assign the final state to the output buffer + s_buff[i] = s; } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -// TODO: optimize (e.g. 
by parallelizing over d_state) kernel void kernel_ssm_scan_f32_group( device const void * src0, device const void * src1, @@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + const int64_t i0 = tpitg.x; const int64_t i1 = tgpig.x; const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq @@ -1903,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + const int64_t i = i0 + i1*nc; + float s0 = s0_buff[i]; + float s = s_buff[i]; + + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} + device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); + device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) dt_block + 
i2*args.nb21); // {nh, nt, ns} + device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); - float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * dA) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; + const float state = (s0 * dA) + (B[i0] * x_dt); + s = state; + + // Parallel sum: This relies on the fact that this kernel will be + // dispatched with each threadgroup having (d_state, 1, 1) threads which + // are subdivided into SIMD groups of size `sgptg`. The goal is to + // compute y = sum({state * C[i] for i in range(d_state)}). + // To parallelize this effectively, we first use simd_sum over each SIMD + // group to compute the sum of each SIMD group, then place the result in + // the SIMD group's indexed bucket in the shared memory. We then sum + // over the individual group sums to compute the final sum. + + // Computed for each thread + float sumf = state * C[i0]; + + // Sum the threads in the simd group => simd sum + sumf = simd_sum(sumf); + + // Once per simd group, place the group sum into the shared buffer + if (tiisg == 0) { + shared[sgitg] = sumf; } - y[0] = sumf; + // Wait for all threads in the threadgroup to reach this point. This + // ensures that all elements of the shared buffer are populated with the + // sum of the individual simd groups. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // For simd group 0 at indices < num simd groups, extract the shared + // simd sum + sumf = 0.0f; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + y[0] = sumf; + } + } // recurse s0 = s; } + + // Assign the final state to the output buffer + s_buff[i] = s; } kernel void kernel_rwkv_wkv6_f32( diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7dd36b2f7..4f74824b5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6656,20 +6656,18 @@ static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgr static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node); struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent); - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", + fprintf(fp, " \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", gparent0 ? (void *) gparent0 : (void *) parent, - gparent0 ? "g" : "x", gparent ? (void *) gparent : (void *) node, - gparent ? "g" : "x", gparent ? "empty" : "vee", gparent ? 
"dashed" : "solid", label); } static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n", - (void *) parent, "x", - (void *) node, "x", + fprintf(fp, " \"%p\" -> \"%p\" [ label = \"%s\"; ]\n", + (void *) parent, + (void *) node, label); } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 99adb3675..cfe86a5bb 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -105,7 +105,7 @@ llama_context::llama_context( { const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - const bool supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false; + supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false; if (!supports_set_rows && !cparams.kv_unified) { LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__); @@ -899,6 +899,12 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + if (!supports_set_rows) { + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1229,6 +1235,12 @@ int llama_context::decode(const llama_batch & batch_inp) { // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); + if (!supports_set_rows) { + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + } + return 0; } diff --git a/src/llama-context.h b/src/llama-context.h index fdbe61207..5c3a1c098 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -287,6 +287,10 @@ private: bool has_evaluated_once = false; + // env: LLAMA_SET_ROWS (temporary) + // ref: https://github.com/ggml-org/llama.cpp/pull/14285 + bool supports_set_rows = false; + // perf mutable int64_t t_start_us = 0; mutable int64_t t_load_us = 0; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index c422cd7be..ec7fd6a42 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -98,7 +98,7 @@ struct llama_hparams { float rope_freq_scale_train; float rope_freq_scale_train_swa; uint32_t n_ctx_orig_yarn; - float rope_yarn_log_mul; + float rope_yarn_log_mul = 0.0f; std::array rope_sections; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c98453520..be7ec183c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1374,7 +1374,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // that have no expert_gating_func model parameter set hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; } - ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; @@ -16291,7 +16291,7 @@ private: { // PLaMo-2 uses combined QKV tensor ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); - cb(qkv, "qkv", il); + cb(qkv, "wqkv", il); // split QKV tensor into Q, K, V const int64_t n_embd_head_q = hparams.n_embd_head_k; @@ -16331,7 +16331,7 @@ private: ext_factor, attn_factor, beta_fast, beta_slow ); - cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il); + cur = build_attn(inp, 
model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il); } cb(cur, "attn_out", il); @@ -16406,8 +16406,9 @@ private: ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv, ggml_view_1d(ctx0, conv_states_all, - (d_conv - 1)*(d_inner)*(n_seqs), - kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + cb(conv_states_all, "mamba_conv1d_state", il); // 1D convolution x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); @@ -16470,9 +16471,9 @@ private: // store last states ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), - ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, - kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)), + ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all)))); + cb(ssm_states_all, "mamba_ssm_states", il); ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); cb(y, "mamba_y_view", il); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 442056443..40ae23151 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2366,7 +2366,7 @@ struct clip_model_loader { // create data context struct ggml_init_params params = { - /*.mem_size =*/ (gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(), + /*.mem_size =*/ static_cast(gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, };
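
Note on the reworked kernel_ssm_scan_f32 / kernel_ssm_scan_f32_group kernels above: the serial per-thread loop over d_state is replaced by one thread per state element plus a two-stage reduction, first a simd_sum within each SIMD group, then a sum of the per-group partials through the threadgroup buffer. The fragment below is a minimal sketch of that same two-stage reduction pattern expressed in CUDA terms (warp shuffles plus shared memory); block_reduce_sum and BLOCK_SIZE are hypothetical names used only for illustration and are not part of this patch.

// Minimal sketch (illustrative, not from this patch): two-stage block reduction,
// the CUDA analogue of the simd_sum + threadgroup-buffer pattern used by the
// Metal SSM scan kernels. Assumes BLOCK_SIZE is a multiple of 32 and <= 1024.
template <int BLOCK_SIZE>
__device__ float block_reduce_sum(float v) {
    const int lane = threadIdx.x % 32;
    const int warp = threadIdx.x / 32;

    // Stage 1: reduce within each warp (analogous to simd_sum per SIMD group).
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_down_sync(0xffffffff, v, offset);
    }

    // Stage 2: one shared-memory bucket per warp, like the threadgroup buffer
    // that lane 0 of each SIMD group writes into in the Metal kernel.
    __shared__ float partial[BLOCK_SIZE/32];
    if (lane == 0) {
        partial[warp] = v;
    }
    __syncthreads();

    // The first warp sums the per-warp partials; thread 0 ends up with the total.
    v = (threadIdx.x < BLOCK_SIZE/32) ? partial[threadIdx.x] : 0.0f;
    if (warp == 0) {
        for (int offset = 16; offset > 0; offset >>= 1) {
            v += __shfl_down_sync(0xffffffff, v, offset);
        }
    }
    return v; // result is valid in thread 0 of the block
}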