CUDA: replace GGML_CUDA_F16 with CUDA arch checks (#15433)

This commit is contained in:
Johannes Gäßler 2025-08-20 16:58:49 +02:00 committed by GitHub
parent fec9519802
commit 7a6e91ad26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 32 additions and 86 deletions

View file

@ -198,10 +198,9 @@ The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enab
The following compilation options are also available to tweak performance:
| Option | Legal values | Default | Description |
|-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|-------------------------------|------------------------|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
| GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models |
| GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models. There may be issues with numerical overflows (except for CDNA and RDNA4) and memory use will be higher. Prompt processing may become faster on recent datacenter GPUs (the custom kernels were tuned primarily for RTX 3000/4000). |
| GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
| GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |

View file

@ -194,7 +194,7 @@ llama_print_timings: total time = 44411.01 ms / 377 tokens
## Orin compile and run
### compile
```sh
make GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_87 GGML_CUDA_F16=1 -j 32
make GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_87 -j 32
```
### run on Orin
### case 1

View file

@ -158,7 +158,6 @@ option(GGML_CUDA "ggml: use CUDA"
option(GGML_MUSA "ggml: use MUSA" OFF)
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF)
set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"ggml: max. batch size for using peer access")
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)

View file

@ -24,12 +24,6 @@ if (CUDAToolkit_FOUND)
# for best performance and to also build real architectures for the most commonly used GPUs.
if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
set(CMAKE_CUDA_ARCHITECTURES "native")
elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
else()
set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
endif()
else()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
@ -91,10 +85,6 @@ if (CUDAToolkit_FOUND)
add_compile_definitions(GGML_CUDA_NO_FA)
endif()
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
add_compile_definitions(GGML_CUDA_F16)
endif()
if (GGML_CUDA_NO_PEER_COPY)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif()

View file

@ -206,14 +206,6 @@ static const char * cu_get_error_str(CUresult err) {
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11010
#ifdef GGML_CUDA_F16
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
#else
typedef float dfloat; // dequantize float
typedef float2 dfloat2;
#endif // GGML_CUDA_F16
#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
@ -559,7 +551,7 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
static __device__ __forceinline__ float get_alibi_slope(
const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1

View file

@ -27,7 +27,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
const int64_t y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;

View file

@ -42,7 +42,7 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
#pragma unroll
for (int j = 0; j < QK8_0; j += 2) {
dfloat2 dq;
float2 dq;
dequantize_q8_0(cxi, 0, j, dq);
*(cdstf + j) = dq.x;
*(cdstf + j + 1) = dq.y;
@ -55,7 +55,7 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
#pragma unroll
for (int j = 0; j < qk/2; j++) {
dfloat2 dq;
float2 dq;
dequant(cxi, 0, j, dq);
*(cdstf + j) = dq.x;
*(cdstf + j + qk/2) = dq.y;

View file

@ -1,48 +1,37 @@
#include "common.cuh"
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;
const dfloat d = x[ib].d;
const float d = x[ib].d;
const int vui = x[ib].qs[iqs];
v.x = vui & 0xF;
v.y = vui >> 4;
#ifdef GGML_CUDA_F16
v = __hsub2(v, {8.0f, 8.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 8.0f) * d;
v.y = (v.y - 8.0f) * d;
#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;
const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);
const float2 dm = __half22float2(x[ib].dm);
const int vui = x[ib].qs[iqs];
v.x = vui & 0xF;
v.y = vui >> 4;
#ifdef GGML_CUDA_F16
v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
#else
v.x = (v.x * d) + m;
v.y = (v.y * d) + m;
#endif // GGML_CUDA_F16
v.x = (v.x * dm.x) + dm.y;
v.y = (v.y * dm.x) + dm.y;
}
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q5_0 * x = (const block_q5_0 *) vx;
const dfloat d = x[ib].d;
const float d = x[ib].d;
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
@ -53,20 +42,14 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
#ifdef GGML_CUDA_F16
v = __hsub2(v, {16.0f, 16.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 16.0f) * d;
v.y = (v.y - 16.0f) * d;
#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;
const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);
const float2 dm = __half22float2(x[ib].dm);
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
@ -77,27 +60,18 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
#ifdef GGML_CUDA_F16
v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
#else
v.x = (v.x * d) + m;
v.y = (v.y * d) + m;
#endif // GGML_CUDA_F16
v.x = (v.x * dm.x) + dm.y;
v.y = (v.y * dm.x) + dm.y;
}
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q8_0 * x = (const block_q8_0 *) vx;
const dfloat d = x[ib].d;
const float d = x[ib].d;
v.x = x[ib].qs[iqs + 0];
v.y = x[ib].qs[iqs + 1];
#ifdef GGML_CUDA_F16
v = __hmul2(v, {d, d});
#else
v.x *= d;
v.y *= d;
#endif // GGML_CUDA_F16
}

View file

@ -32,7 +32,7 @@ static __global__ void k_get_rows(
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
float2 v;
dequantize_kernel(src0_row, ib, iqs, v);
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);

View file

@ -3672,10 +3672,6 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
features.push_back({ "NO_PEER_COPY", "1" });
#endif
#ifdef GGML_CUDA_F16
features.push_back({ "F16", "1" });
#endif
#ifdef GGML_CUDA_USE_GRAPHS
features.push_back({ "USE_GRAPHS", "1" });
#endif

View file

@ -87,7 +87,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
}
#ifdef GGML_CUDA_F16
#ifdef FAST_FP16_AVAILABLE
const float2 tmp = __half22float2(__hmul2(dm4, ds8));
const float d4d8 = tmp.x;
const float m4s8 = tmp.y;
@ -96,7 +96,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
const float2 ds8f = __half22float2(ds8);
const float d4d8 = dm4f.x * ds8f.x;
const float m4s8 = dm4f.y * ds8f.y;
#endif // GGML_CUDA_F16
#endif // FAST_FP16_AVAILABLE
// scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
@ -158,7 +158,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
}
#ifdef GGML_CUDA_F16
#ifdef FAST_FP16_AVAILABLE
const float2 tmp = __half22float2(__hmul2(dm5, ds8));
const float d5d8 = tmp.x;
const float m5s8 = tmp.y;
@ -167,7 +167,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
const float2 ds8f = __half22float2(ds8);
const float d5d8 = dm5f.x * ds8f.x;
const float m5s8 = dm5f.y * ds8f.y;
#endif // GGML_CUDA_F16
#endif // FAST_FP16_AVAILABLE
// scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
@ -201,7 +201,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
}
#ifdef GGML_CUDA_F16
#ifdef FAST_FP16_AVAILABLE
const float2 tmp = __half22float2(__hmul2(dm8, ds8));
const float d8d8 = tmp.x;
const float m8s8 = tmp.y;
@ -210,7 +210,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
const float2 ds8f = __half22float2(ds8);
const float d8d8 = dm8f.x * ds8f.x;
const float m8s8 = dm8f.y * ds8f.y;
#endif // GGML_CUDA_F16
#endif // FAST_FP16_AVAILABLE
// scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
return sumi*d8d8 + m8s8 / (QI8_1 / vdr);

View file

@ -96,10 +96,6 @@ if (MUSAToolkit_FOUND)
add_compile_definitions(GGML_CUDA_NO_FA)
endif()
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
add_compile_definitions(GGML_CUDA_F16)
endif()
if (GGML_CUDA_NO_PEER_COPY)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif()