From a818f3028d1497a51cb2b8eb7d993ad58784940e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 24 Jun 2024 17:43:42 +0200 Subject: [PATCH 1/5] CUDA: use MMQ instead of cuBLAS by default (#8075) --- CMakeLists.txt | 15 ++++--- Makefile | 3 ++ README.md | 5 ++- ggml-cuda.cu | 96 ++++++++++++++++---------------------------- ggml-cuda/common.cuh | 36 ++--------------- ggml-cuda/mmq.cu | 36 +++++++++++++++-- ggml-cuda/mmq.cuh | 53 ++++++++++++++++-------- ggml-cuda/mmvq.cuh | 2 + 8 files changed, 124 insertions(+), 122 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 49ba45356..1acf4bb08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -102,7 +102,8 @@ option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" option(LLAMA_CUDA "llama: use CUDA" OFF) option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF) option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) -option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF) +option(LLAMA_CUDA_FORCE_MMQ "llama: always use mmq kernels instead of cuBLAS" OFF) +option(LLAMA_CUDA_FORCE_CUBLAS "llama: always use cuBLAS instead of mmq kernels" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) @@ -416,13 +417,14 @@ if (LLAMA_CUDA) if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) # 52 == lowest CUDA 12 standard - # 60 == f16 CUDA intrinsics + # 60 == FP16 CUDA intrinsics # 61 == integer CUDA intrinsics - # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster + # 70 == FP16 tensor cores + # 75 == int8 tensor cores if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics + set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work endif() endif() @@ -447,6 +449,9 @@ if (LLAMA_CUDA) if (LLAMA_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() + if (LLAMA_CUDA_FORCE_CUBLAS) + add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) + endif() if (LLAMA_CUDA_NO_VMM) add_compile_definitions(GGML_CUDA_NO_VMM) endif() diff --git a/Makefile b/Makefile index 3aad77394..f6e8eb73e 100644 --- a/Makefile +++ b/Makefile @@ -537,6 +537,9 @@ endif # LLAMA_CUDA_FORCE_DMMV ifdef LLAMA_CUDA_FORCE_MMQ MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ endif # LLAMA_CUDA_FORCE_MMQ +ifdef LLAMA_CUDA_FORCE_CUBLAS + MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS +endif # LLAMA_CUDA_FORCE_CUBLAS ifdef LLAMA_CUDA_DMMV_X MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) else diff --git a/README.md b/README.md index 40793c8ea..a54ee3951 100644 --- a/README.md +++ b/README.md @@ -510,8 +510,9 @@ Building the program with BLAS support may lead to some performance improvements |--------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | - | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | - | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of dequantization + matrix multiplication kernels instead of leveraging Math libraries. | | + | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | + | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). Speed for large batch sizes will be worse but VRAM consumption will be lower. | + | LLAMA_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f914efd71..2dda03924 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -152,16 +152,16 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; -#if defined(GGML_CUDA_FORCE_MMQ) - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); +#ifdef GGML_CUDA_FORCE_MMQ + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); -#endif -#if defined(CUDA_USE_TENSOR_CORES) - GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__); + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); +#endif // GGML_CUDA_FORCE_MMQ +#ifdef GGML_CUDA_FORCE_CUBLAS + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__); -#endif + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); +#endif // GGML_CUDA_FORCE_CUBLAS GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; @@ -1873,9 +1873,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); - int64_t min_compute_capability = INT_MAX; + bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + bool use_mul_mat_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + bool any_gpus_with_slow_fp16 = false; - bool any_pascal_with_slow_fp16 = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; auto & tensor_split = buft_ctx->tensor_split; @@ -1885,55 +1893,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - if (min_compute_capability > ggml_cuda_info().devices[id].cc) { - min_compute_capability = ggml_cuda_info().devices[id].cc; - } - if (ggml_cuda_info().devices[id].cc == 610) { - any_pascal_with_slow_fp16 = true; - } + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); } } else { - min_compute_capability = ggml_cuda_info().devices[ctx.device].cc; - any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610; + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); } - // check data types and tensor shapes for custom matrix multiplication kernels: - bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; - - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - - bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - - const bool fp16_performance_good = min_compute_capability >= CC_RDNA1; - -#ifdef CUDA_USE_TENSOR_CORES - use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3; -#endif // CUDA_USE_TENSOR_CORES - -#else - - // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0) - const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16; - - // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1 - use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A; - use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A; - -#ifdef CUDA_USE_TENSOR_CORES - // when tensor cores are available, use them for large batch size - // ref: https://github.com/ggerganov/llama.cpp/pull/3776 - use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE); -#endif // CUDA_USE_TENSOR_CORES - -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - // if mmvq is available it's a better choice than dmmv: #ifndef GGML_CUDA_FORCE_DMMV use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; @@ -1947,21 +1918,22 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // KQ single-batch + if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + // FP32 precision KQ single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); - } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { - // KQV single-batch + } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + // FP32 precision KQV single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { - // KQ + KQV multi-batch - ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) + && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + // KQ + KQV multi-batch without FlashAttention + ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 5c8662535..8d00db6c1 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -146,23 +146,6 @@ #define CC_RDNA2 (CC_OFFSET_AMD + 1030) #define CC_RDNA3 (CC_OFFSET_AMD + 1100) -// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication -// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant -// for large computational tasks. the drawback is that this requires some extra amount of VRAM: -// - 7B quantum model: +100-200 MB -// - 13B quantum model: +200-400 MB -// -//#define GGML_CUDA_FORCE_MMQ - -// TODO: improve this to be correct for more hardware -// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores -#if !defined(GGML_CUDA_FORCE_MMQ) -#define CUDA_USE_TENSOR_CORES -#endif - -#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels -#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available - #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #if defined(_MSC_VER) @@ -343,15 +326,15 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int #define INT8_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING -static bool fast_fp16_available(const int cc) { +static constexpr bool fast_fp16_available(const int cc) { return cc >= CC_PASCAL && cc != 610; } -static bool fp16_mma_available(const int cc) { +static constexpr bool fp16_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; } -static bool int8_mma_available(const int cc) { +static constexpr bool int8_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_TURING; } @@ -643,19 +626,6 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; -static constexpr int get_mmq_x_max_host(int cc) { -#ifdef CUDA_USE_TENSOR_CORES - return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; -#else - return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; -#endif // CUDA_USE_TENSOR_CORES -} - -// Round rows to this value for --split-mode row: -static constexpr int get_mmq_y_host(int cc) { - return cc >= CC_VOLTA ? 128 : 64; -} - ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index 6dbd85fef..0308beacc 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -69,7 +69,13 @@ void ggml_cuda_op_mul_mat_q( GGML_UNUSED(src1_ddf_i); } -bool ggml_cuda_supports_mmq(enum ggml_type type) { +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { +#ifdef GGML_CUDA_FORCE_CUBLAS + return false; +#endif // GGML_CUDA_FORCE_CUBLAS + + bool mmq_supported; + switch (type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -81,8 +87,32 @@ bool ggml_cuda_supports_mmq(enum ggml_type type) { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: - return true; + mmq_supported = true; + break; default: - return false; + mmq_supported = false; + break; } + + if (!mmq_supported) { + return false; + } + + if (int8_mma_available(cc)) { + return true; + } + + if (cc < MIN_CC_DP4A) { + return false; + } + +#ifdef GGML_CUDA_FORCE_MMQ + return true; +#endif //GGML_CUDA_FORCE_MMQ + + if (cc < CC_OFFSET_AMD) { + return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + } + + return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 0f7f8ae51..1fc948be5 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -7,6 +7,8 @@ #include #include +#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. + typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride); typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0); typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max); @@ -24,25 +26,42 @@ struct tile_x_sizes { int sc; }; -// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row - -static constexpr __device__ int get_mmq_x_max_device() { -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - return 64; +static constexpr int get_mmq_x_max_host(const int cc) { + return int8_mma_available(cc) ? 128 : +#ifdef GGML_CUDA_FORCE_MMQ + cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; #else -#if __CUDA_ARCH__ >= CC_VOLTA -#ifdef CUDA_USE_TENSOR_CORES - return MMQ_MAX_BATCH_SIZE; -#else - return 128; -#endif // CUDA_USE_TENSOR_CORES -#else - return 64; -#endif // __CUDA_ARCH__ >= CC_VOLTA -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64; +#endif // GGML_CUDA_FORCE_MMQ } -// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row +static constexpr __device__ int get_mmq_x_max_device() { +#ifdef INT8_MMA_AVAILABLE + return 128; +#else // INT8_MMA_AVAILABLE + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + return 128; +#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + +#if __CUDA_ARCH__ >= CC_VOLTA +#ifdef GGML_CUDA_FORCE_MMQ + return MMQ_DP4A_MAX_BATCH_SIZE; +#else // GGML_CUDA_FORCE_MMQ + return 128; +#endif // GGML_CUDA_FORCE_MMQ +#else // __CUDA_ARCH__ >= CC_VOLTA + + return 64; +#endif // __CUDA_ARCH__ >= CC_VOLTA + +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#endif // INT8_MMA_AVAILABLE +} + +static constexpr int get_mmq_y_host(const int cc) { + return int8_mma_available(cc) || cc >= CC_VOLTA ? 128 : 64; +} static constexpr __device__ int get_mmq_y_device() { #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) @@ -2590,4 +2609,4 @@ void ggml_cuda_op_mul_mat_q( const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); -bool ggml_cuda_supports_mmq(enum ggml_type type); +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11); diff --git a/ggml-cuda/mmvq.cuh b/ggml-cuda/mmvq.cuh index 88c42c4b7..d9e42fdd6 100644 --- a/ggml-cuda/mmvq.cuh +++ b/ggml-cuda/mmvq.cuh @@ -1,5 +1,7 @@ #include "common.cuh" +#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. + void ggml_cuda_op_mul_mat_vec_q( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, From 3b099bcd9cbf2434f90cbe40eba6fa2189ed1d02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 24 Jun 2024 22:15:33 +0200 Subject: [PATCH 2/5] CUDA: fix MMQ writeback for int8 tensor cores (#8100) --- ggml-cuda/mmq.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 1fc948be5..31fcbf139 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -2054,15 +2054,13 @@ static __device__ __forceinline__ void mmq_write_back_mma( static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); #endif // INT8_MMA_AVAILABLE - dst += (threadIdx.y % ntx) * mma_C::J*stride; - #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - const int j = j0 + mma_C::get_j(l); + const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l); if (j > j_max) { continue; From 2df373ac40ea581ccca8a58c713f03ad9d4b658d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 25 Jun 2024 01:22:33 +0200 Subject: [PATCH 3/5] CUDA: fix matrix multiplication algorithm choice (#8102) --- ggml-cuda.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2dda03924..0acfda91d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1924,16 +1924,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // FP32 precision KQV single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) + && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + // KQ + KQV multi-batch without FlashAttention + ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) - && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { - // KQ + KQV multi-batch without FlashAttention - ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } From 083bacce14c1aaf9976aa40e8266cdc25ac749d3 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Tue, 25 Jun 2024 10:19:20 +0800 Subject: [PATCH 4/5] [SYCL] Re-enabled mul_mat_batched_sycl (#8095) --- ggml-sycl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index e5ddf4a34..db045336f 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -4620,7 +4620,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { From f702a90e245499283d6de0b287701c723cda2a87 Mon Sep 17 00:00:00 2001 From: HatsuneMikuUwU33 <173229399+HatsuneMikuUwU33@users.noreply.github.com> Date: Tue, 25 Jun 2024 10:44:48 +0200 Subject: [PATCH 5/5] Update control vector help (#8104) --- common/common.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 1dc532651..0ca7b4430 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1538,9 +1538,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --lora FNAME", "apply LoRA adapter (implies --no-mmap)" }); options.push_back({ "*", " --lora-scaled FNAME S", "apply LoRA adapter with user defined scaling S (implies --no-mmap)" }); options.push_back({ "*", " --lora-base FNAME", "optional model to use as a base for the layers modified by the LoRA adapter" }); - options.push_back({ "*", " --control-vector FNAME", "add a control vector" }); + options.push_back({ "*", " --control-vector FNAME", "add a control vector\n" + "note: this argument can be repeated to add multiple control vectors" }); options.push_back({ "*", " --control-vector-scaled FNAME SCALE", - "add a control vector with user defined scaling SCALE" }); + "add a control vector with user defined scaling SCALE\n" + "note: this argument can be repeated to add multiple scaled control vectors" }); options.push_back({ "*", " --control-vector-layer-range START END", "layer range to apply the control vector(s) to, start and end inclusive" }); options.push_back({ "*", "-m, --model FNAME", "model path (default: models/$filename with filename from --hf-file\n"