Programmatic Dependent Launch (PDL) for more performance on newer NVIDIA GPUs (Hopper+) (#22522)

* Adds initial PDL setup.

* Adds PDL barriers based on simple heuristic: place "sync" before first input pointer access, and "launch" after last write, e.g. to tensors like dst.

* Further optimization pass of the first half of kernels

* Optimized PDL barriers for the second batch of kernels

* Further refinements after rebase.

* Moves pdl logic to separate function, removes some whitespace

* Strips post-hoc PDL logic

* Adds stream capture PDL setup. Enrolls quantize_q8_1 to leverage pdl to
overlap execution with previous kernels

* Enrolls mul_mat_vec_q, rms_norm_f32 and k_bin_bcast (partly) into PDL

* Enrolls mmvf, rope, set-rows and topk kernels for gpt-oss into PDL

* Introduce ggml_cuda_kernel_launch, to abstract away cudaLaunchKernelEx,
to enable hip/musa compatibility

* Enrolls cpy_scalar_contiguous, k_get_rows_float and rms_norm_f32

* Enrolls flash_attn_combine_results

* Fix: Drops needless and broken check of CUDA arch for PDL. PDL either
works or is without effect.

* Enrolls flash-attention kernels to pdl

* Fix: inlines ggml_cuda_kernel_launch, and uses perfect forwarding for
kernels args. This fixes PDL.

* Perf: Enrolls k_bin_bcast variadic template invocation into PDL, via
and template alias and template expansion

* Enrolls all remaining kernels for qwen3-coder-next into PDL

* Remove all PDL LC calls to create a baseline

* Added LC according to internal guidance and tested kernel performance.

* Enrols missing qwen3-5 kernels passively into PDL.

* Kernel optimizations (LC signals) for qwen3.5

* Enrolls ssm-scan kernels into PDL

* Adds GGML_CUDA_PDL command line option to toggle PDL.

* Fix: Ada and lower compilation by guarding PDL calls correctly

* Cleanup: Removes commented out GGML_CUDA_PDL_LC

* Cleanup: Removes experimental comments

* Adds 90-virtual to build script so that Hopper GPUs can leverage PDL.

* Adds stricter checks to enable PDL, adds env-check to disable it, and removes now superfluous compile option to enable PDL.

* Fix: Correct PDL en/disablement based on device-side arch check. Host
side check is UB. Required moving from macros to inlined functions

* Fix: default-disable PDL. Enable by setting GGML_CUDA_ENABLE_PDL=1

* Enable PDL by default for Hopper+ devices

* Enrolls softcap_f32 and two flash_attn kernels into PDL.

* Improves flash attn PDL barrier placement

* Fix: Perf regression on ada; excludes ada and below from PDL launches

* Improves some sync barrier placements

* Drops superfluous constructor

* Adds #endif guard comments

* Reverts experimental change to top-k-moe.cu, which moved expensive allocations
in front of the PDL barrier. It did not have a meaningful impact.

* Exchanges GGML_CUDA_DISABLE_PDL with GGML_CUDA_PDL. IFF GGML_CUDA_PDL=0
PDL is disabled

* Revert "Drops superfluous constructor". Adds const to remaining
arguments

This reverts commit 12b1d250da0089ae02a9bb71bbb3fd6d70f6f2f1.

* Cleanup: Removes and fixes some comments and whitespace

* Clarifies comment of sync-barrier position

* Relocates and refactors PDL launch functions and accessories

* Adds error checking to the regular kernel launch path

* Drops "auto" in favor of "ggml_cuda_kernel_params"

* Adds "const" to ggml_cuda_kernel_launch_params

* [Whitespace] Adds final newline to common.cuh to make editorconfig CI job happy
This commit is contained in:
Andreas Kieslinger 2026-05-20 13:59:02 +02:00 committed by GitHub
parent 29f1482221
commit e947228222
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 310 additions and 113 deletions

View file

@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
# 86 == RTX 3000, needs CUDA v11.1
# 89 == RTX 4000, needs CUDA v11.8
# 90 == Hopper H100/200, needs CUDA v11.8
# 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
#
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
@ -33,7 +34,7 @@ if (CUDAToolkit_FOUND)
list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-virtual)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")

View file

@ -2,6 +2,9 @@
#include <cstdint>
#include <utility>
template<typename T, size_t>
using type_for_index = T;
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
@ -52,6 +55,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const int s12,
const int s13,
src1_ptrs... src1s) {
ggml_cuda_pdl_lc();
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
@ -72,6 +76,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
ggml_cuda_pdl_sync();
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);
@ -141,6 +146,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const int i10 = fastmodulo(i0, ne10);
ggml_cuda_pdl_sync();
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
@ -282,35 +288,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
{
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream);
ggml_cuda_kernel_launch(k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
{
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(k_bin_bcast<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00 ,s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
}
}
}
@ -333,6 +328,7 @@ static __global__ void k_repeat_back(
}
T sum = 0;
ggml_cuda_pdl_sync();
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {

View file

@ -5,6 +5,7 @@
#include "ggml-cuda.h"
#include <cstdint>
#include <cstdlib>
#include <memory>
#if defined(GGML_USE_HIP)
@ -27,6 +28,7 @@
#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#if defined(GGML_USE_HIP)
@ -50,6 +52,7 @@
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_HOPPER 900
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
#define GGML_CUDA_CC_BLACKWELL 1200
@ -107,6 +110,24 @@
# define GGML_CUDA_USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA.
// __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code.
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080
# define GGML_CUDA_USE_PDL
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080
static __device__ __forceinline__ void ggml_cuda_pdl_sync() {
#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
cudaGridDependencySynchronize();
#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
}
static __device__ __forceinline__ void ggml_cuda_pdl_lc() {
#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
cudaTriggerProgrammaticLaunchCompletion();
#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
}
#ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) {
return false;
@ -165,6 +186,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
static const char * cublas_get_error_str(const cublasStatus_t err) {
return cublasGetStatusString(err);
@ -1487,3 +1509,67 @@ struct ggml_cuda_mm_fusion_args_device {
const void * gate_bias = nullptr;
ggml_glu_op glu_op;
};
struct ggml_cuda_kernel_launch_params {
dim3 block_nums;
dim3 block_dims;
size_t shmem;
cudaStream_t stream;
// size_t shmem
ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_)
: block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {}
// Some call sites pass ints instead of the required size_t. This 2nd constructor casts int->size_t to avoid these -Wnarrowing warnings.
ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_)
: block_nums(block_nums_), block_dims(block_dims_), shmem((size_t)shmem_), stream(stream_) {}
};
#if defined(GGML_CUDA_USE_PDL)
struct ggml_cuda_pdl_config {
cudaLaunchAttribute attr;
cudaLaunchConfig_t cfg;
ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) {
attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
attr.val.programmaticStreamSerializationAllowed = 1;
cfg = {};
cfg.gridDim = params.block_nums;
cfg.blockDim = params.block_dims;
cfg.dynamicSmemBytes = params.shmem;
cfg.stream = params.stream;
cfg.attrs = &attr;
cfg.numAttrs = 1;
}
// Delete due to &attr
ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete;
ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete;
ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete;
};
#endif //defined(GGML_CUDA_USE_PDL)
template<typename Kernel, typename... Args>
static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) {
#if defined(GGML_CUDA_USE_PDL)
static const bool env_pdl_enabled = []() {
const char * env = getenv("GGML_CUDA_PDL");
return env == nullptr || std::atoi(env) != 0;
}();
if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) {
auto pdl_cfg = ggml_cuda_pdl_config(launch_params);
CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... ));
return;
}
#endif //defined(GGML_CUDA_USE_PDL)
kernel<<<launch_params.block_nums, launch_params.block_dims, launch_params.shmem, launch_params.stream>>>(std::forward<Args>(args)... );
CUDA_CHECK(cudaGetLastError());
}

View file

@ -15,6 +15,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont
const int64_t n = ne0 * ne1 * ne2;
ggml_cuda_pdl_sync();
for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) {
if constexpr (dim == 0) {
const int64_t row = i / ne0;
@ -64,8 +65,8 @@ static void concat_f32_cuda(const float * x,
const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
if (dim == 0) {
concat_f32_cont<0>
<<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
return;
}
if (dim == 1) {

View file

@ -16,6 +16,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
ggml_cuda_pdl_lc();
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
@ -36,6 +37,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
ggml_cuda_pdl_sync();
cpy_1(cx + x_offset, cdst + dst_offset);
}
@ -59,6 +61,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
__shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
int cur_tile_buf = 0;
ggml_cuda_pdl_sync();
#pragma unroll
for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
@ -142,6 +145,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
ggml_cuda_pdl_sync();
cpy_blck(cx + x_offset, cdst + dst_offset);
}
@ -168,6 +172,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
ggml_cuda_pdl_sync();
cpy_blck(cx + x_offset, cdst + dst_offset);
}
@ -182,6 +187,7 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const
const src_t * x = (const src_t *) cx;
dst_t * dst = (dst_t *) cdst;
ggml_cuda_pdl_sync();
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}
@ -192,8 +198,8 @@ cudaStream_t stream) {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar_contiguous<src_t, dst_t>, launch_params, cx, cdst, ne);
}
template<typename src_t, typename dst_t, bool transposed = false>
@ -223,13 +229,15 @@ static void ggml_cpy_scalar_cuda(
GGML_ASSERT(grid_z < USHRT_MAX);
dim3 dimGrid(grid_x, grid_y, grid_z);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
}

View file

@ -636,6 +636,7 @@ static __global__ void flash_attn_mask_to_KV_max(
if (tid < WARP_SIZE) {
buf_iw[tid] = 1;
}
ggml_cuda_pdl_sync();
__syncthreads();
int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
@ -687,6 +688,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform(
const uint3 fd_iter_j_z,
const uint3 fd_iter_j) {
constexpr int ncols = ncols1*ncols2;
ggml_cuda_pdl_lc();
const int tile_idx = blockIdx.x; // One block per output tile.
const int j = blockIdx.y;
@ -718,6 +720,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform(
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
ggml_cuda_pdl_sync();
// Load the partial result that needs a fixup
float dst_val = *dst;
float max_val;
@ -809,6 +812,7 @@ static __global__ void flash_attn_stream_k_fixup_general(
float dst_val = 0.0f;
float max_val = 0.0f;
float rowsum = 0.0f;
ggml_cuda_pdl_sync();
{
dst_val = *dst;
@ -867,6 +871,7 @@ static __global__ void flash_attn_combine_results(
const float2 * __restrict__ VKQ_meta,
float * __restrict__ dst,
const int parallel_blocks) {
ggml_cuda_pdl_lc();
// Dimension 0: threadIdx.x
// Dimension 1: blockIdx.x
// Dimension 2: blockIdx.y
@ -890,6 +895,7 @@ static __global__ void flash_attn_combine_results(
__builtin_assume(tid < D);
extern __shared__ float2 meta[];
ggml_cuda_pdl_sync();
for (int i = tid; i < 2*parallel_blocks; i += D) {
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
}
@ -1146,7 +1152,9 @@ void launch_fattn(
const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
GGML_ASSERT(block_dim.x % warp_size == 0);
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream);
ggml_cuda_kernel_launch(fattn_kernel, launch_params,
(const char *) Q->data,
K_data,
V_data,
@ -1176,9 +1184,9 @@ void launch_fattn(
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2};
flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr,
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream);
ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>, launch_params,
(float *) KQV->data, dst_tmp_meta.ptr,
Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk,
gqa_ratio, bpt, fd0, fd1, fd2);
} else if (ntiles_dst % blocks_num.x != 0) {
@ -1193,9 +1201,9 @@ void launch_fattn(
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr,
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream);
ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>, launch_params,
(float *) KQV->data, dst_tmp_meta.ptr,
Q->ne[1], Q->ne[2], gqa_ratio, total_work,
fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k);
}
@ -1204,9 +1212,9 @@ void launch_fattn(
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
flash_attn_combine_results<DV>
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream);
ggml_cuda_kernel_launch(flash_attn_combine_results<DV>, launch_params,
dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
}
CUDA_CHECK(cudaGetLastError());
}

View file

@ -1724,6 +1724,7 @@ static __global__ void flash_attn_ext_f16(
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
ggml_cuda_pdl_sync(); // TODO optimize placement
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE))
// Skip unused kernel variants for faster compilation:

View file

@ -894,6 +894,8 @@ static __global__ void flash_attn_tile(
}
float KQ_sum[cpw] = {0.0f};
ggml_cuda_pdl_sync();
// Load Q data, convert to FP16 if fast:
#pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) {

View file

@ -40,6 +40,7 @@ static __global__ void flash_attn_ext_vec(
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
ggml_cuda_pdl_lc();
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
@ -136,6 +137,8 @@ static __global__ void flash_attn_ext_vec(
#endif // V_DOT2_F32_F16_AVAILABLE
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
ggml_cuda_pdl_sync();
if constexpr (Q_q8_1) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {

View file

@ -86,6 +86,7 @@ static __global__ void flash_attn_ext_f16(
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
ggml_cuda_pdl_sync();
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.

View file

@ -1,4 +1,5 @@
#include "gated_delta_net.cuh"
#include "ggml-cuda/common.cuh"
template <int S_v, bool KDA, bool keep_rs_t>
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
@ -53,6 +54,7 @@ gated_delta_net_cuda(const float * q,
float s_shard[rows_per_lane];
// state is stored transposed: M[col][i] = S[i][col], row col is contiguous
ggml_cuda_pdl_sync();
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
@ -189,28 +191,29 @@ static void launch_gated_delta_net(
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
switch (S_v) {
case 16:
gated_delta_net_cuda<16, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
break;
case 32:
gated_delta_net_cuda<32, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
break;
case 64: {
gated_delta_net_cuda<64, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
break;
}
case 128: {
gated_delta_net_cuda<128, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);

View file

@ -11,6 +11,7 @@ static __global__ void k_get_rows(
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
ggml_cuda_pdl_sync();
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
@ -48,6 +49,8 @@ static __global__ void k_get_rows_float(
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
ggml_cuda_pdl_lc();
ggml_cuda_pdl_sync();
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
@ -83,6 +86,7 @@ static __global__ void k_get_rows_back_float(
float sum = 0.0f;
ggml_cuda_pdl_sync();
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
@ -156,7 +160,8 @@ static void get_rows_cuda_float(
GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
const uint3 ne12_fdv = init_fastdiv_values(ne12);
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{block_nums, block_dims, 0, stream};
ggml_cuda_kernel_launch(k_get_rows_float<src0_t, dst_t>, launch_params,
src0_d, src1_d, dst_d,
ne00, /*ne01, ne02, ne03,*/
/*ne10,*/ ne11, ne12_fdv, /*ne13,*/

View file

@ -67,9 +67,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132
if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/true>, launch_params, src0_d, dst_d, ncols);
} else {
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/true>, launch_params, src0_d, dst_d, ncols);
}
}

View file

@ -21,6 +21,7 @@ static __global__ void mul_mat_vec_f(
int channel_y;
int sample_dst;
ggml_cuda_pdl_sync();
if constexpr (is_multi_token_id) {
// Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
token_idx = blockIdx.z;
@ -298,6 +299,7 @@ static __global__ void mul_mat_vec_f(
static_assert(std::is_same_v<T, void>, "unsupported type");
}
ggml_cuda_pdl_lc();
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
@ -382,11 +384,13 @@ static void mul_mat_vec_f_switch_fusion(
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {
const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, nbytes_shared, stream};
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
ggml_cuda_kernel_launch(mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id>, launch_params,
x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
return;
@ -395,8 +399,8 @@ static void mul_mat_vec_f_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
ggml_cuda_kernel_launch(mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id>, launch_params,
x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);

View file

@ -424,6 +424,7 @@ static __global__ void mul_mat_vec_q(
uint32_t channel_y;
uint32_t sample_dst;
ggml_cuda_pdl_sync();
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
sample_dst = blockIdx.z;
@ -683,8 +684,9 @@ static void mul_mat_vec_q_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream);
ggml_cuda_kernel_launch(mul_mat_vec_q<type, c_ncols_dst, true, small_k>, launch_params,
vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
return;
@ -693,8 +695,9 @@ static void mul_mat_vec_q_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream);
ggml_cuda_kernel_launch(mul_mat_vec_q<type, c_ncols_dst, false, small_k>, launch_params,
vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}

View file

@ -18,6 +18,7 @@ static __global__ void norm_f32(
float2 mean_var = make_float2(0.0f, 0.0f);
ggml_cuda_pdl_sync();
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
mean_var.x += xi;
@ -46,6 +47,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
float tmp = 0.0f; // partial sum for thread in warp
ggml_cuda_pdl_sync();
for (int j = start; j < end; j += block_size) {
tmp += x[j];
}
@ -95,6 +97,7 @@ static __global__ void rms_norm_f32(const float * x,
const uint3 add_nrows_packed = make_uint3(0, 0, 0),
const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
ggml_cuda_pdl_lc();
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@ -124,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x,
float tmp = 0.0f; // partial sum for thread in warp
ggml_cuda_pdl_sync();
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
@ -163,6 +167,7 @@ static __global__ void rms_norm_back_f32(
float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
ggml_cuda_pdl_sync();
for (int col = tid; col < ncols; col += block_size) {
const float xfi = xf[col];
sum_xx += xfi * xfi;
@ -253,6 +258,7 @@ static __global__ void l2_norm_f32(
float tmp = 0.0f; // partial sum for thread in warp
ggml_cuda_pdl_sync();
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
@ -261,6 +267,7 @@ static __global__ void l2_norm_f32(
// sum up partial sums
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
ggml_cuda_pdl_lc();
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
@ -300,10 +307,19 @@ static void rms_norm_f32_cuda(
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
const ggml_cuda_kernel_launch_params launch_params = {blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
ggml_cuda_kernel_launch(rms_norm_f32<256, false>, launch_params,
x, dst, ncols, stride_row, stride_channel, stride_sample, eps,
// underlying cudaLaunchKernelEx does not support default params
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0),
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
ggml_cuda_kernel_launch(rms_norm_f32<1024, false>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps,
// underlying cudaLaunchKernelEx does not support default params
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0),
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
}
}
@ -346,14 +362,20 @@ static void rms_norm_mul_f32_cuda(const float * x,
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
ggml_cuda_kernel_launch(rms_norm_f32<256, true>, launch_params,
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed,
// underlying cudaLaunchKernelEx does not support default params
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
ggml_cuda_kernel_launch(rms_norm_f32<1024, true>, launch_params,
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed,
// underlying cudaLaunchKernelEx does not support default params
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
}
} else {
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
@ -367,14 +389,16 @@ static void rms_norm_mul_f32_cuda(const float * x,
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims,block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
ggml_cuda_kernel_launch(rms_norm_f32<256, true, true>, launch_params,
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
ggml_cuda_kernel_launch(rms_norm_f32<1024, true, true>, launch_params,
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
@ -399,10 +423,12 @@ static void l2_norm_f32_cuda(
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, 0, stream};
ggml_cuda_kernel_launch(l2_norm_f32<WARP_SIZE>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
ggml_cuda_kernel_launch(l2_norm_f32<1024>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}

View file

@ -6,6 +6,7 @@ static __global__ void quantize_q8_1(
const float * __restrict__ x, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
ggml_cuda_pdl_lc();
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= ne0) {
@ -28,6 +29,7 @@ static __global__ void quantize_q8_1(
const int64_t ib = i_cont / QK8_1; // block index
const int64_t iqs = i_cont % QK8_1; // quant index
ggml_cuda_pdl_sync();
const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
float amax = fabsf(xi);
float sum = xi;
@ -196,6 +198,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
ggml_cuda_pdl_sync();
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t i02 = i2;
const int64_t i03 = i3;
@ -288,6 +291,7 @@ static __global__ void quantize_mmq_q8_1(
const int64_t i3 = blockIdx.z / ne2;
const int64_t i00 = i0;
ggml_cuda_pdl_sync();
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t i02 = i2;
const int64_t i03 = i3;
@ -378,7 +382,8 @@ void quantize_row_q8_1_cuda(
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, block_size, 0, stream);
ggml_cuda_kernel_launch(quantize_q8_1, launch_params, x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
GGML_UNUSED(type_src0);
}

View file

@ -10,6 +10,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
const int num_unroll = 8;
float temp[num_unroll];
float sum_temp[num_unroll] = { 0.0f };
ggml_cuda_pdl_sync();
for (int i = col; i < ncols;) {
for (int j = 0; j < num_unroll; ++j) {
if (i < ncols) {

View file

@ -134,6 +134,7 @@ static __global__ void rope_neox(const T * x,
const float * freq_factors,
const int64_t * row_indices,
const int set_rows_stride) {
ggml_cuda_pdl_lc();
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne00) {
@ -148,6 +149,7 @@ static __global__ void rope_neox(const T * x,
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
ggml_cuda_pdl_sync();
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
@ -216,6 +218,7 @@ static __global__ void rope_multi(const T * x,
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
ggml_cuda_pdl_sync();
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
@ -300,6 +303,7 @@ static __global__ void rope_vision(const T * x,
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
ggml_cuda_pdl_sync();
const int sect_dims = sections.v[0] + sections.v[1];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
@ -399,13 +403,14 @@ static void rope_neox_cuda(const T * x,
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, 0, stream};
if (freq_factors == nullptr) {
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
ggml_cuda_kernel_launch(rope_neox<forward, false, T, D>, launch_params,
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else {
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
ggml_cuda_kernel_launch(rope_neox<forward, true, T, D>, launch_params,
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
}
@ -443,11 +448,13 @@ static void rope_multi_cuda(const T * x,
const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(rope_multi<forward, false, T>, launch_params,
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(rope_multi<forward, true, T>, launch_params,
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}

View file

@ -3,9 +3,11 @@
#define MAX_GRIDDIM_X 0x7FFFFFFF
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
ggml_cuda_pdl_lc();
int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
ggml_cuda_pdl_sync();
for (int64_t i = tid; i < nelements; i += stride) {
dst[i] = scale * x[i] + bias;
}
@ -13,7 +15,8 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(scale_f32, launch_params, x, dst, scale, bias, nelements);
}
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View file

@ -53,6 +53,7 @@ static __global__ void k_set_rows_quant(const float * __restrict__ src0,
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
const int64_t i10 = i01;
ggml_cuda_pdl_sync();
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
@ -157,7 +158,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0,
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
const int64_t i10 = i01;
ggml_cuda_pdl_sync();
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
ggml_cuda_pdl_lc();
const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
@ -203,9 +206,11 @@ static void set_rows_cuda(
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
ne11_fd, ne12_fd);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_size, block_size, 0, stream);
ggml_cuda_kernel_launch(k_set_rows<src_t, idx_t, dst_t>, launch_params,
src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
ne11_fd, ne12_fd);
}
}

View file

@ -1,18 +1,21 @@
#include "softcap.cuh"
static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
ggml_cuda_pdl_lc();
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
ggml_cuda_pdl_sync();
dst[i] = tanhf(scale * x[i]) * softcap;
}
static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(softcap_f32, launch_params, x, dst, scale, softcap, k);
}
// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE

View file

@ -1,3 +1,4 @@
#include "common.cuh"
#include "ssm-conv.cuh"
#include "unary.cuh"
@ -7,6 +8,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int64_t n_t) {
ggml_cuda_pdl_lc();
GGML_UNUSED(src0_nb0);
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
@ -23,6 +25,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };
ggml_cuda_pdl_sync();
#pragma unroll
for (size_t j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
@ -128,8 +131,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa
constexpr int kNC = decltype(NC)::value;
if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
ggml_cuda_kernel_launch(ssm_conv_f32<apply_silu, threads, kNC>, launch_params, src0, src1, bias, src0_nb0, src0_nb1,
src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);

View file

@ -26,6 +26,7 @@ __global__ void __launch_bounds__(splitD, 1)
const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
const size_t L = L_template == 0 ? L_param : L_template;
ggml_cuda_pdl_sync();
const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
@ -135,6 +136,7 @@ __global__ void __launch_bounds__(d_state, 1)
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
ggml_cuda_pdl_sync();
// TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
@ -206,7 +208,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
constexpr int num_warps = threads/WARP_SIZE;
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
ggml_cuda_kernel_launch(ssm_scan_f32_group<128/WARP_SIZE, 128>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
@ -215,7 +218,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
constexpr int num_warps = threads/WARP_SIZE;
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
ggml_cuda_kernel_launch(ssm_scan_f32_group<256/WARP_SIZE, 256>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
@ -231,58 +235,59 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
if (d_state == 16) {
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream);
switch (n_tok)
{
case 1:
ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 1>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 2:
ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 2>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 3:
ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 3>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 4:
ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 4>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 5:
ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 5>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 6:
ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 6>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 7:
ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 7>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 8:
ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 8>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
default:
ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 0>, launch_params,
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);

View file

@ -7,10 +7,12 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int
const dim3 block_nums(nrows, 1, 1);
if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, x, dst, ncols);
} else {
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, x, dst, ncols);
}
}
@ -34,10 +36,12 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if ((nrows / nsm) < 2) {
// Increase num threads to 512 for small nrows to better hide the latency
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, src0_d, dst_d, ncols);
} else {
// Enough active SMs to hide latency, use smaller blocks to allow better scheduling
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, src0_d, dst_d, ncols);
}
}

View file

@ -105,6 +105,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
wt[i] = -INFINITY;
}
ggml_cuda_pdl_sync();
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
@ -161,6 +162,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
output_weights[i] = 0.f;
}
ggml_cuda_pdl_lc();
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
int max_expert = threadIdx.x;
@ -271,51 +273,52 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
cudaStream_t stream = ctx.stream();
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
switch (n_expert) {
case 1:
topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<1, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 2:
topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<2, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 4:
topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<4, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 8:
topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<8, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 16:
topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<16, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 32:
topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<32, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 64:
topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<64, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 128:
topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<128, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 256:
topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<256, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 512:
topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<512, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 576:
topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
ggml_cuda_kernel_launch(topk_moe_cuda<576, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
default:
GGML_ASSERT(false && "fatal error");

View file

@ -116,19 +116,22 @@ static __device__ __forceinline__ float op_trunc(float x) {
template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
ggml_cuda_pdl_lc();
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
ggml_cuda_pdl_sync();
dst[i] = (T)op((float)x[i]);
}
template <float (*op)(float), typename T>
static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(unary_op_kernel<op, T>, launch_params, x, dst, k);
}
template <float (*op)(float)>
@ -258,6 +261,7 @@ void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
template <float (*op)(float), typename T>
static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
ggml_cuda_pdl_lc();
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -268,13 +272,15 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst,
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
ggml_cuda_pdl_sync();
dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);
}
template <float (*op)(float), typename T>
static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(unary_gated_op_kernel<op, T>, launch_params, x, g, dst, k, n, o0, o1);
}
template <float (*op)(float)>