CUDA: missing PDL sync for FWHT, better fallback (#23690)

This commit is contained in:
Johannes Gäßler 2026-05-26 05:05:51 +02:00 committed by GitHub
parent 35c9b1f39e
commit 192d8ae8b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 17 additions and 25 deletions

View file

@ -19,6 +19,7 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows,
float reg[el_w];
const int lane = threadIdx.x;
ggml_cuda_pdl_sync();
#pragma unroll
for (int i = 0; i < el_w; ++i) {
reg[i] = src[i * warp_size + lane] * scale;
@ -57,10 +58,11 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows,
}
}
void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_shape(src, dst));
GGML_ASSERT(ggml_is_contiguous(src));
GGML_ASSERT(ggml_is_contiguous(dst));
if (!ggml_is_contiguous(src) || !ggml_is_contiguous(dst)) {
return false;
}
const int n = src->ne[0];
const int64_t rows = ggml_nrows(src);
@ -68,7 +70,6 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src,
float * dst_d = (float *) dst->data;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
GGML_ASSERT(n % warp_size == 0);
const int rows_per_block = 4;
const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block;
@ -83,26 +84,18 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src,
switch (n) {
case 64:
{
ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale);
return true;
case 128:
{
ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale);
return true;
case 256:
{
ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale);
return true;
case 512:
{
ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale);
return true;
default:
GGML_ABORT("fatal error");
return false;
}
}

View file

@ -1,3 +1,4 @@
#include "common.cuh"
void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst);
// Returns whether the Fast Walsh-Hadamard transform could be used.
bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst);

View file

@ -2596,9 +2596,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
const int32_t hint = ggml_get_op_params_i32(dst, 1);
if (hint == GGML_HINT_SRC0_IS_HADAMARD) {
GGML_ASSERT(!split);
ggml_cuda_op_fwht(ctx, src1, dst);
if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) {
return;
}