From 9725a313be0528214c4a02fed906ddaf7b3f712e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sat, 25 Apr 2026 14:15:03 +0200
Subject: [PATCH] CUDA: reduce MMQ stream-k overhead (#22298)

* CUDA: reduce MMQ stream-k overhead

* use 32 bit integers for kbc
---
 ggml/src/ggml-cuda/mmq.cuh | 277 ++++++++++++++++++-------------------
 1 file changed, 138 insertions(+), 139 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index b1a319de9..91a1b737a 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -3478,10 +3478,10 @@ template
 static __global__ void mul_mat_q(
         const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
-        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
-        const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        const int ncols_max) {
+        const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
+        const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const uint3 ntx) {

     // Skip unused template specializations for faster compilation:
     if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -3495,8 +3495,7 @@ static __global__ void mul_mat_q(
     constexpr int qk = ggml_cuda_type_traits::qk;
     constexpr int mmq_y = get_mmq_y_device();

-    const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
-    const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
+    const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y

     // Initialize the ids for writing back data with just the index.
     // For regular matrix multiplications this is never changed.
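Note on the signature change above: the plain int divisors (ncols_x, channel_ratio, nchannels_y, sample_ratio, nsamples_y, and the derived ntx) are replaced by uint3 values that the launcher builds with init_fastdiv_values(), so the kernel can divide by them with a multiply and a shift instead of hardware integer division. A minimal sketch of the idea follows, assuming a {magic, shift, divisor} packing; the exact magic-number derivation in ggml's common.cuh may differ, and the names make_fastdiv/fast_div/fast_mod are illustrative, not ggml API.

    // Illustrative only: exact for 0 <= n < 2^31 and 1 <= d < 2^31, which covers the
    // tile and k-block counts used here (see the kbc overflow assert added by this patch).
    #include <cstdint>

    struct fastdiv_vals { uint64_t magic; uint32_t shift; uint32_t d; };

    static fastdiv_vals make_fastdiv(uint32_t d) {            // host side, once per launch
        uint32_t L = 0;
        while ((uint32_t(1) << L) < d) {
            L++;                                               // L = ceil(log2(d))
        }
        const uint32_t shift = 32 + L;
        const uint64_t magic = ((uint64_t(1) << shift) + d - 1) / d; // ceil(2^shift / d)
        return {magic, shift, d};
    }

    static uint32_t fast_div(uint32_t n, const fastdiv_vals & fd) {  // == n / fd.d
        return (uint32_t) ((uint64_t(n)*fd.magic) >> fd.shift);
    }

    static uint32_t fast_mod(uint32_t n, const fastdiv_vals & fd) {  // == n % fd.d
        return n - fast_div(n, fd)*fd.d;
    }

The modulo is recovered from the quotient and the raw divisor, which is why the packed uint3 also carries the plain divisor and the kernel can read it directly as blocks_per_ne00.z, ntx.z, and so on.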
@@ -3517,8 +3516,9 @@ static __global__ void mul_mat_q(
     // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
 #if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
     {
-        const int wt = blockIdx.z / nchannels_y;
-        const int zt = blockIdx.z - wt*nchannels_y;
+        const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y);
+        const int wt = tmp2.x;
+        const int zt = tmp2.y;
         const int jt = blockIdx.y;
         const int it = blockIdx.x;

@@ -3561,40 +3561,40 @@ static __global__ void mul_mat_q(
         const int tile_x_max_i = nrows_x - it*mmq_y - 1;
         const int tile_y_max_j = col_diff - jt*mmq_x - 1;

-        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+        const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

         constexpr bool fixup = false;
         mul_mat_q_process_tile
             (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
-             tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
+             tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z);
         return;
     }
 #endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA

-    constexpr int ITER_K = get_iter_k(type);
-
-    const int64_t blocks_per_ne00 = ncols_x / qk;
-    constexpr int blocks_per_iter = ITER_K / qk;
+    constexpr int ITER_K = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;

     // kbc == k block continuous, current index in continuous ijk space.
-    int64_t kbc      = (int64_t) blockIdx.x     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
-    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+    int kbc      = int64_t(blockIdx.x)    *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
+    int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;

-    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
-    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
+    kbc      -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
+    kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter;

     // kb0 == k index when doing the matrix multiplication for an output tile.
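The kbc/kbc_stop computation just above is the core of the stream-k split: every (sample, channel, tile_x, tile_y, k-block) work item is flattened into one index space and each CUDA block gets a contiguous slice, snapped down so it never starts in the middle of one ITER_K-wide iteration. A plain C++ sketch of that arithmetic, with illustrative names (stream_k_start is not a ggml function):

    #include <cstdint>

    // Block b of nblocks gets the half-open slice [stream_k_start(b), stream_k_start(b+1)).
    static int64_t stream_k_start(int64_t b, int64_t total_kblocks, int64_t nblocks,
                                  int64_t blocks_per_ne00, int64_t blocks_per_iter) {
        int64_t kbc = b*total_kblocks / nblocks;                  // even split over the grid
        kbc        -= (kbc % blocks_per_ne00) % blocks_per_iter;  // align within the current tile
        return kbc;
    }

With total_kblocks = nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 and nblocks = gridDim.x this reproduces kbc and kbc_stop above; the patch keeps the 64-bit intermediate product but now stores the result in a 32-bit kbc, relying on the GGML_ASSERT added in the launcher further below to guarantee that it fits.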
-    int kb0_start = kbc % blocks_per_ne00;
-    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
-    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
-        int tmp = kbc;
-        const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
-        tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
-        const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
-        tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
-        const int zt = tmp / (ntx*blocks_per_ne00);
-        tmp -= zt * (ntx*blocks_per_ne00);
-        const int jt = tmp / blocks_per_ne00;
+    int kb0_start = fastmodulo(kbc, blocks_per_ne00);
+    int kb0_stop  = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc));
+    while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) {
+        int tmp = fastdiv(kbc, blocks_per_ne00);
+        uint2 tmp2 = fast_div_modulo(tmp, ntx);
+        const int jt = tmp2.y;
+        tmp = tmp2.x;
+        tmp2 = fast_div_modulo(tmp, nchannels_y);
+        const int zt = tmp2.y;
+        tmp = tmp2.x;
+        tmp2 = fast_div_modulo(tmp, nsamples_y);
+        const int wt = tmp2.y;
+        const int it = tmp2.x;

         // Defaults for regular matrix multiplication:
         int col_low = 0;
@@ -3612,11 +3612,11 @@ static __global__ void mul_mat_q(
             offset_dst = 0;

             if (jt*mmq_x >= col_diff) {
-                kbc += blocks_per_ne00;
-                kbc -= kbc % blocks_per_ne00;
+                kbc += blocks_per_ne00.z;
+                kbc -= fastmodulo(kbc, blocks_per_ne00);

                 kb0_start = 0;
-                kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+                kb0_stop  = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));

                 continue;
             }
@@ -3641,32 +3641,34 @@ static __global__ void mul_mat_q(
         const int tile_x_max_i = nrows_x - it*mmq_y - 1;
         const int tile_y_max_j = col_diff - jt*mmq_x - 1;

-        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+        const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

         constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
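The loop above decodes the flat kbc index into tile coordinates. The old code divided by precomputed mixed products; the new code strips the k dimension once and then peels off jt, zt and wt from the fastest-varying end with fast_div_modulo, which is the same decomposition. A small plain C++ check of that equivalence, with ordinary division standing in for the fastdiv helpers (unpack_check is an illustrative name, not part of the patch):

    #include <cassert>

    // kbc = (((it*nsamples_y + wt)*nchannels_y + zt)*ntx + jt)*blocks_per_ne00 + k
    static void unpack_check(int kbc, int nsamples_y, int nchannels_y, int ntx, int blocks_per_ne00) {
        // old: divide by precomputed mixed products
        int tmp = kbc;
        const int it_old = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
        tmp -= it_old * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
        const int wt_old = tmp / (nchannels_y*ntx*blocks_per_ne00);
        tmp -= wt_old * (nchannels_y*ntx*blocks_per_ne00);
        const int zt_old = tmp / (ntx*blocks_per_ne00);
        tmp -= zt_old * (ntx*blocks_per_ne00);
        const int jt_old = tmp / blocks_per_ne00;

        // new: strip the k dimension once, then peel off jt, zt, wt from the fast end
        tmp = kbc / blocks_per_ne00;
        const int jt = tmp % ntx;         tmp /= ntx;
        const int zt = tmp % nchannels_y; tmp /= nchannels_y;
        const int wt = tmp % nsamples_y;
        const int it = tmp / nsamples_y;

        assert(it == it_old && wt == wt_old && zt == zt_old && jt == jt_old);
    }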
         mul_mat_q_process_tile
             (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
              tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);

-        kbc += blocks_per_ne00;
-        kbc -= kbc % blocks_per_ne00;
+        kbc += blocks_per_ne00.z;
+        kbc -= fastmodulo(kbc, blocks_per_ne00);

         kb0_start = 0;
-        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+        kb0_stop  = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
     }

     if (kbc >= kbc_stop) {
         return;
     }

-    int tmp = kbc;
-    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
-    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
-    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
-    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
-    const int zt = tmp / (ntx*blocks_per_ne00);
-    tmp -= zt * (ntx*blocks_per_ne00);
-    const int jt = tmp / blocks_per_ne00;
+    int tmp = fastdiv(kbc, blocks_per_ne00);
+    uint2 tmp2 = fast_div_modulo(tmp, ntx);
+    const int jt = tmp2.y;
+    tmp = tmp2.x;
+    tmp2 = fast_div_modulo(tmp, nchannels_y);
+    const int zt = tmp2.y;
+    tmp = tmp2.x;
+    tmp2 = fast_div_modulo(tmp, nsamples_y);
+    const int wt = tmp2.y;
+    const int it = tmp2.x;

     // Defaults for regular matrix multiplication:
     int col_low = 0;
@@ -3708,7 +3710,7 @@ static __global__ void mul_mat_q(
     const int tile_x_max_i = nrows_x - it*mmq_y - 1;
     const int tile_y_max_j = col_diff - jt*mmq_x - 1;

-    const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+    const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

     constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     mul_mat_q_process_tile
@@ -3717,46 +3719,37 @@ static __global__ void mul_mat_q(
 }

 template
-static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
-                                                const int32_t * expert_bounds,
-                                                float * __restrict__ dst,
-                                                const float * __restrict__ tmp_last_tile,
-                                                const int ncols_x,
-                                                const int nrows_x,
-                                                const int ncols_dst,
-                                                const size_t stride_col_dst,
-                                                const int nchannels_y,
-                                                const size_t stride_channel_dst,
-                                                const int nsamples_y,
-                                                const size_t stride_sample_dst,
-                                                const int ncols_max) {
-    constexpr int mmq_y = get_mmq_y_device();
-    constexpr int qk = ggml_cuda_type_traits::qk;
-    constexpr int ITER_K = get_iter_k(type);
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1)
+static __global__ void mul_mat_q_stream_k_fixup(
+        const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
+        float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst,
+        const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y,
+        const int stride_sample_dst, const uint3 ntx) {
+    constexpr int mmq_y = get_mmq_y_device();
+    constexpr int qk = ggml_cuda_type_traits::qk;
+    constexpr int ITER_K = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;

-    constexpr int blocks_per_iter = ITER_K / qk;
-    const int64_t blocks_per_ne00 = ncols_x / qk;
-
-    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int nwarps = mmq_get_nwarps_device()/2;
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
+    float sum[mmq_x / nwarps] = {0.0f};
+    const int i = blockIdx.y*warp_size + threadIdx.x;

-    const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
-    const int nty = (nrows_x + mmq_y - 1) / mmq_y;
+    const int nty = (nrows_x + mmq_y - 1) / mmq_y;

     const int bidx0 = blockIdx.x;

     // kbc == k block continuous, current index in continuous ijk space.
-    int64_t kbc0      = (int64_t) bidx0      *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
-    int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+    int kbc0      = int64_t(blockIdx.x)    *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
+    int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;

-    kbc0      -= (kbc0      % blocks_per_ne00) % blocks_per_iter;
-    kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
+    kbc0      -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter;
+    kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter;

     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
-    const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
-    const bool did_not_write_last      = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
+    const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0;
+    const bool did_not_write_last      = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0;
     if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
         return;
     }
@@ -3765,11 +3758,11 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,

     // Iterate over previous blocks and sum up partial sums written to fixup buffer.
     // All CUDA blocks that get here must have a previous block that needs a fixup.
-    int64_t bidx = bidx0 - 1;
-    int64_t kbc_stop = kbc0;
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
     while(true) {
-        int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
-        kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
+        int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
+        kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;

         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
@@ -3779,20 +3772,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,

         any_fixup = true;
+
 #pragma unroll
         for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
             const int j = j0 + threadIdx.y;

-#pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
-                const int i = i0 + threadIdx.x;
-
-                sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
-            }
+            sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
         }

         // If this block started in a previous tile we are done and don't need to combine additional partial results.
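The rewritten fixup kernel above no longer loops over the mmq_y rows of the tile: the launcher (further below in this patch) adds a gridDim.y of mmq_y/warp_size and halves the warps per block, so each fixup block handles one warp-wide row strip with a fixed i = blockIdx.y*warp_size + threadIdx.x, and the per-thread accumulator shrinks from mmq_x*mmq_y/(nwarps*warp_size) to mmq_x/nwarps floats. A back-of-the-envelope sketch with example sizes; the concrete mmq_x, mmq_y and warp-count values are not taken from this patch:

    // Example numbers only; the real values depend on the tile configuration and GPU.
    constexpr int mmq_x      = 64;
    constexpr int mmq_y      = 128;
    constexpr int warp_size  = 32;
    constexpr int nwarps_old = 8;             // stand-in for mmq_get_nwarps_device()
    constexpr int nwarps_new = nwarps_old/2;  // halved for the fixup kernel

    constexpr int sums_per_thread_old = mmq_x*mmq_y / (nwarps_old*warp_size); // 32 floats
    constexpr int sums_per_thread_new = mmq_x / nwarps_new;                   // 16 floats
    constexpr int fixup_grid_y        = mmq_y / warp_size;                    // 4 blocks in y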
-        if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
+        if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) {
             break;
         }

         bidx--;
@@ -3803,14 +3792,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
         return;
     }

-    int tmp = kbc0;
-    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
-    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
-    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
-    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
-    const int zt = tmp / (ntx*blocks_per_ne00);
-    tmp -= zt * (ntx*blocks_per_ne00);
-    const int jt = tmp / blocks_per_ne00;
+    int tmp = fastdiv(kbc0, blocks_per_ne00);
+    uint2 tmp2 = fast_div_modulo(tmp, ntx);
+    const int jt = tmp2.y;
+    tmp = tmp2.x;
+    tmp2 = fast_div_modulo(tmp, nchannels_y);
+    const int zt = tmp2.y;
+    tmp = tmp2.x;
+    tmp2 = fast_div_modulo(tmp, nsamples_y);
+    const int wt = tmp2.y;
+    const int it = tmp2.x;

     if (!ids_dst) {
         const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
@@ -3818,6 +3809,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,

         const int i_max = nrows_x - it*mmq_y - 1;
         const int j_max = ncols_dst - jt*mmq_x - 1;
+        if (need_check && i > i_max) {
+            return;
+        }

 #pragma unroll
         for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -3827,16 +3821,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
                 return;
             }

-#pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
-                const int i = i0 + threadIdx.x;
-
-                if (need_check && i > i_max) {
-                    continue;
-                }
-
-                dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
-            }
+            dst[j*stride_col_dst + i] += sum[j0/nwarps];
         }
         return;
     }
@@ -3856,6 +3841,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,

     const int i_max = nrows_x - it*mmq_y - 1;
     const int j_max = col_diff - jt*mmq_x - 1;
+    if (need_check && i > i_max) {
+        return;
+    }

 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -3865,16 +3853,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
             return;
         }

-#pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
-            const int i = i0 + threadIdx.x;
-
-            if (need_check && i > i_max) {
-                continue;
-            }
-
-            dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
-        }
+        dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps];
     }
 }

@@ -3922,29 +3901,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     const int channel_ratio = args.nchannels_y / args.nchannels_x;
     const int sample_ratio  = args.nsamples_y / args.nsamples_x;

+    const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits::qk);
+    const uint3 ntx_fd             = init_fastdiv_values(ntx);
+    const uint3 nchannels_y_fd     = init_fastdiv_values(args.nchannels_y);
+    const uint3 nsamples_y_fd      = init_fastdiv_values(args.nsamples_y);
+    const uint3 channel_ratio_fd   = init_fastdiv_values(channel_ratio);
+    const uint3 sample_ratio_fd    = init_fastdiv_values(sample_ratio);
+
     if (!args.use_stream_k) {
         if (args.nrows_x % mmq_y == 0) {
             constexpr bool need_check = false;
             mul_mat_q<<>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
-                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
-                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
-                 args.ncols_max);
+                 blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                 channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                 sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 ntx_fd);
         } else {
             constexpr bool need_check = true;
             mul_mat_q<<>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
-                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
-                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
-                 args.ncols_max);
+                 blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                 channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                 sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 ntx_fd);
         }
         return;
     }

-    const dim3 block_nums_stream_k(nsm, 1, 1);
-    const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
+    // For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles.
+    // This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important.
+    const int ntiles_dst = ntx * nty * ntzw;
+    const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm;
+    const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves);
+    const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1);
+
+    GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow.
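The new launch logic above falls back from stream-k to plain data-parallel tiling when the output tiles already fill the GPU almost evenly, since in that case the fixup pass costs more than the idle tail it would hide. A worked example with made-up numbers (132 is an H100-class SM count; the tile count is arbitrary):

    // Worked example for tiles_efficiency_percent.
    constexpr int nsm        = 132;   // number of SMs
    constexpr int ntiles_dst = 1000;  // ntx*nty*ntzw output tiles

    constexpr int tiles_nwaves             = (ntiles_dst + nsm - 1) / nsm;        // 8 waves
    constexpr int tiles_efficiency_percent = 100*ntiles_dst / (nsm*tiles_nwaves); // 94
    // 94 >= 90 on NVIDIA: launch ntiles_dst blocks (plain tiling), so the fixup pass is skipped.
    // Below 90: launch nsm blocks and keep the stream-k + fixup path.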
+
+    const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0;

     ggml_cuda_pool & pool = ctx.pool(id);
     ggml_cuda_pool_alloc tmp_fixup(pool);
@@ -3952,40 +3946,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
         tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
     }

+    const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1);
+    const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z);
+
     if (args.nrows_x % mmq_y == 0) {
         constexpr bool need_check = false;
         mul_mat_q<<>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
-             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
-             args.ncols_max);
+             blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+             channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+             sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+             ntx_fd);

         if (!fixup_needed) {
             return;
         }

-        mul_mat_q_stream_k_fixup<<>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
-             args.ncols_max);
+        CUDA_CHECK(cudaGetLastError());
+        mul_mat_q_stream_k_fixup<<>>
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
+             args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
+             ntx_fd);
     } else {
         constexpr bool need_check = true;
         mul_mat_q<<>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
-             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
-             args.ncols_max);
+             blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+             channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+             sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+             ntx_fd);

         if (!fixup_needed) {
            return;
        }

-        mul_mat_q_stream_k_fixup<<>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
-             args.ncols_max);
+        CUDA_CHECK(cudaGetLastError());
+        mul_mat_q_stream_k_fixup<<>>
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
+             args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
+             ntx_fd);
     }
 }
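Two consequences of the launch changes above, sketched in plain C++ with an illustrative helper name (needs_fixup is not part of ggml): the fixup pass is only required when a tile can be split across stream-k blocks, which never happens once the heuristic picks one block per tile, and block_dims_fixup halves the y dimension to match the /2 applied to nwarps and __launch_bounds__ in the rewritten fixup kernel. The added CUDA_CHECK(cudaGetLastError()) simply surfaces launch-configuration errors from the first kernel before the fixup kernel is queued.

    // Illustrative helper, not ggml API:
    static bool needs_fixup(int ntiles_dst, int grid_x) {
        // grid_x is block_nums_stream_k.x: either nsm (stream-k) or ntiles_dst (plain tiling).
        // If the tile count divides evenly into the grid, every block processes whole tiles,
        // writes its complete result to dst, and the fixup kernel plus its temporary buffer
        // can be skipped; with grid_x == ntiles_dst the remainder is always zero.
        return ntiles_dst % grid_x != 0;
    }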