diff --git a/ggml/src/ggml-cuda/fwht.cu b/ggml/src/ggml-cuda/fwht.cu index 74e94d844..184dc254c 100644 --- a/ggml/src/ggml-cuda/fwht.cu +++ b/ggml/src/ggml-cuda/fwht.cu @@ -19,6 +19,7 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, float reg[el_w]; const int lane = threadIdx.x; + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < el_w; ++i) { reg[i] = src[i * warp_size + lane] * scale; @@ -57,10 +58,11 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, } } -void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_shape(src, dst)); - GGML_ASSERT(ggml_is_contiguous(src)); - GGML_ASSERT(ggml_is_contiguous(dst)); + if (!ggml_is_contiguous(src) || !ggml_is_contiguous(dst)) { + return false; + } const int n = src->ne[0]; const int64_t rows = ggml_nrows(src); @@ -68,7 +70,6 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, float * dst_d = (float *) dst->data; const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; - GGML_ASSERT(n % warp_size == 0); const int rows_per_block = 4; const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; @@ -83,26 +84,18 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, switch (n) { case 64: - { - ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); + return true; case 128: - { - ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); + return true; case 256: - { - ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); + return true; case 512: - { - ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); + return true; default: - GGML_ABORT("fatal error"); + return false; } } diff --git a/ggml/src/ggml-cuda/fwht.cuh b/ggml/src/ggml-cuda/fwht.cuh index fa4c30477..cf3df94ca 100644 --- a/ggml/src/ggml-cuda/fwht.cuh +++ b/ggml/src/ggml-cuda/fwht.cuh @@ -1,3 +1,4 @@ #include "common.cuh" -void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); +// Returns whether the Fast Walsh-Hadamard transform could be used. +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1bb09ac80..23d1c0692 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2596,9 +2596,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32; const int32_t hint = ggml_get_op_params_i32(dst, 1); - if (hint == GGML_HINT_SRC0_IS_HADAMARD) { - GGML_ASSERT(!split); - ggml_cuda_op_fwht(ctx, src1, dst); + if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) { return; }