Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	tests/test-backend-ops.cpp
Concedo 2025-06-29 15:10:26 +08:00
commit d383c03554
4 changed files with 182 additions and 88 deletions


@@ -728,3 +728,25 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
return nullptr;
}
}
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cuda<float, nv_bfloat16>;
case GGML_TYPE_F16:
return convert_unary_cuda<half, nv_bfloat16>;
default:
return nullptr;
}
}
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F16:
return convert_unary_cuda<half, float>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16, float>;
default:
return nullptr;
}
}
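For context on how these non-contiguous ("_nc") converters are meant to be used: each getter returns a function pointer matching the to_t_nc_cuda_t signature declared in the header below, which walks the source with explicit per-dimension element strides. A rough usage sketch follows, with assumed variable names that are not part of the commit:

    // Hypothetical call site (not in the commit): convert an F16 view to BF16.
    to_bf16_nc_cuda_t to_bf16 = ggml_get_to_bf16_nc_cuda(GGML_TYPE_F16);
    if (to_bf16 != nullptr) {
        to_bf16(src_f16_data, dst_bf16_data,
                ne00, ne01, ne02, ne03,   // extents of the source view
                s01, s02, s03,            // strides in elements, per to_t_nc_cuda_t
                stream);
    }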


@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);


@@ -1750,7 +1750,7 @@ static void ggml_cuda_op_mul_mat(
}
static __global__ void k_compute_batched_ptrs(
const half * src0_as_f16, const half * src1_as_f16, char * dst,
const void * src0_as_f16, const void * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst,
int64_t ne12, int64_t ne13,
int64_t ne23,
@@ -1773,83 +1773,131 @@ static __global__ void k_compute_batched_ptrs(
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
}
static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
// Type traits for mapping ggml types to CUDA/cuBLAS types
template<ggml_type T>
struct batched_mul_mat_traits;
template<>
struct batched_mul_mat_traits<GGML_TYPE_F32> {
using cuda_type = float;
static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
static inline const cudaDataType_t data_type = CUDA_R_32F;
static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
static inline const float alpha = 1.0f;
static inline const float beta = 0.0f;
static inline const void* get_alpha() { static const float val = alpha; return &val; }
static inline const void* get_beta() { static const float val = beta; return &val; }
static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
};
template<>
struct batched_mul_mat_traits<GGML_TYPE_BF16> {
using cuda_type = nv_bfloat16;
static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
static inline const cudaDataType_t data_type = CUDA_R_16BF;
static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
static inline const float alpha = 1.0f;
static inline const float beta = 0.0f;
static inline const void* get_alpha() { static const float val = alpha; return &val; }
static inline const void* get_beta() { static const float val = beta; return &val; }
static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
};
template<>
struct batched_mul_mat_traits<GGML_TYPE_F16> {
using cuda_type = half;
static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
static inline const cudaDataType_t data_type = CUDA_R_16F;
static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
static inline const half alpha = 1.0;
static inline const half beta = 0.0;
static inline const void* get_alpha() { static const half val = alpha; return &val; }
static inline const void* get_beta() { static const half val = beta; return &val; }
static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
};
template<ggml_type src0_type>
static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
using traits = batched_mul_mat_traits<src0_type>;
using cuda_t = typename traits::cuda_type;
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == src0_type);
GGML_ASSERT(ggml_is_contiguous(dst));
// Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
// As long as dst is contiguous this does not matter though.
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_TENSOR_BINARY_OP_LOCALS
const int64_t ne_dst = ggml_nelements(dst);
cudaStream_t main_stream = ctx.stream();
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
const half * src0_f16 = (const half *) src0->data;
float * dst_ddf = (float *) dst->data;
const half * src1_f16 = (const half *) src1->data;
const size_t ts_src1 = ggml_type_size(src1->type);
GGML_ASSERT(nb10 == ts_src1);
int64_t s11 = nb11 / ts_src1;
int64_t s12 = nb12 / ts_src1;
int64_t s13 = nb13 / ts_src1;
ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
// convert src1 to fp16
if (src1->type != GGML_TYPE_F16) {
const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
const cuda_t * src0_ptr = nullptr;
const cuda_t * src1_ptr = nullptr;
ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
// Handle src0
src0_ptr = (const cuda_t *) src0->data;
// Handle src1 - convert if necessary
if (src1->type == src0_type) {
src1_ptr = (const cuda_t *) src1->data;
} else {
// Convert src1 to target type using traits conversion functions
const int64_t ne_src1 = ggml_nelements(src1);
src1_f16_alloc.alloc(ne_src1);
GGML_ASSERT(to_fp16_cuda != nullptr);
src1_alloc.alloc(ne_src1);
to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
src1_f16 = src1_f16_alloc.get();
const auto convert_func = traits::get_nc_converter(src1->type);
GGML_ASSERT(convert_func != nullptr);
convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
src1_ptr = src1_alloc.get();
s11 = ne10;
s12 = ne11*s11;
s13 = ne12*s12;
}
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
// Setup destination buffer
ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
char * dst_t;
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;
// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
cublasComputeType_t cu_compute_type = traits::compute_type;
cudaDataType_t cu_data_type = traits::data_type;
cudaDataType_t cu_data_type_a = traits::data_type;
cudaDataType_t cu_data_type_b = traits::data_type;
const void * alpha = traits::get_alpha();
const void * beta = traits::get_beta();
const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;
const void * alpha = &alpha_f16;
const void * beta = &beta_f16;
const float beta_f32 = 0.0f;
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
dst_t = (char *) dst_f16.alloc(ne_dst);
nbd2 /= sizeof(float) / sizeof(half);
nbd3 /= sizeof(float) / sizeof(half);
if constexpr (src0_type == GGML_TYPE_F32) {
dst_t = (char *) dst_ddf; // Direct F32 output
} else {
dst_t = (char *) dst_temp.alloc(ne_dst);
nbd2 /= sizeof(float) / sizeof(cuda_t);
nbd3 /= sizeof(float) / sizeof(cuda_t);
}
} else {
dst_t = (char *) dst_ddf;
cu_compute_type = CUBLAS_COMPUTE_32F;
cu_data_type = CUDA_R_32F;
cu_data_type = CUDA_R_32F;
alpha = &alpha_f32;
beta = &beta_f32;
beta = &beta_f32;
}
int id = ggml_cuda_get_device();
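A note on the alpha/beta handling in this function: cuBLAS expects the scaling factors to match the scale type implied by the compute type, i.e. half values for CUBLAS_COMPUTE_16F and float values for CUBLAS_COMPUTE_32F, which is why the traits store typed alpha/beta constants and the code falls back to the float pair whenever the compute type is forced to 32F. A minimal sketch of that correspondence, illustrative only and not part of the commit:

    // Illustrative helper (not in the commit): pick the alpha constant whose
    // storage type matches the cuBLAS compute type, mirroring the traits above.
    static const void * alpha_for(cublasComputeType_t ct) {
        static const half  one_h = 1.0;   // scale type CUDA_R_16F for CUBLAS_COMPUTE_16F
        static const float one_f = 1.0f;  // scale type CUDA_R_32F for CUBLAS_COMPUTE_32F
        return ct == CUBLAS_COMPUTE_16F ? (const void *) &one_h : (const void *) &one_f;
    }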
@@ -1857,7 +1905,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
cu_compute_type = CUBLAS_COMPUTE_32F;
alpha = &alpha_f32;
beta = &beta_f32;
beta = &beta_f32;
}
GGML_ASSERT(ne12 % ne02 == 0);
@@ -1867,35 +1915,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
#if 0
// use cublasGemmEx
{
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
int i03 = i13 / r3;
int i02 = i12 / r2;
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
}
#else
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
src1_f16, CUDA_R_16F, s11, s12, // strideB
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
src1_ptr, cu_data_type_b, s11, s12, // strideB
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
ne12*ne13,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
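To make the stride arguments above concrete: cuBLAS takes leading dimensions and batch strides in elements, so ggml's byte strides are divided by the element size (nb00), while the dst strides are already expressed in elements. A worked example under assumed shapes, illustrative and not from the commit:

    // Assume src0 is contiguous F16 with ne00=4096, ne01=4096, ne02=32, ne03=1, so
    // nb00 = 2 bytes, nb01 = ne00*nb00 = 8192, nb02 = ne01*nb01 = 33554432.
    int64_t lda     = 8192 / 2;        // nb01/nb00 = 4096 elements (leading dimension of A)
    int64_t strideA = 33554432 / 2;    // nb02/nb00 = 16777216 elements between batched A matrices
    // For dst: ldc = ne0 and strideC = ne0*ne1, both already element counts.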
@@ -1906,34 +1934,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
size_t src1_stride_size = sizeof(cuda_t);
dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
src0_f16, src1_f16, dst_t,
src0_ptr, src1_ptr, dst_t,
ptrs_src.get(), ptrs_dst.get(),
ne12, ne13,
ne23,
nb02, nb03,
src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
(src1->type == src0_type) ? nb12 : s12*src1_stride_size,
(src1->type == src0_type) ? nb13 : s13*src1_stride_size,
nbd2, nbd3,
r2, r3);
CUDA_CHECK(cudaGetLastError());
CUBLAS_CHECK(
cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
(const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
ne23,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
#endif
if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
// Convert output back to F32 if needed
if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
}
}
static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
switch (src0->type) {
case GGML_TYPE_F32:
ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
break;
case GGML_TYPE_BF16:
ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
break;
case GGML_TYPE_F16:
ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
break;
default:
GGML_ABORT("Unsupported type");
}
}
@@ -1985,6 +2034,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
//TODO update for generic tensor parallelism
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
if (!split && use_mul_mat_vec) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -1993,8 +2048,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_q) {
ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
!ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
} else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// general KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_mul_mat_vec) {


@@ -321,7 +321,7 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
}
struct vk_device_struct {
std::mutex mutex;
std::recursive_mutex mutex;
vk::PhysicalDevice physical_device;
vk::PhysicalDeviceProperties properties;
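The member change above (std::mutex replaced by std::recursive_mutex) makes the device lock re-entrant: several helpers below, such as ggml_vk_host_get, now take the lock themselves and can be reached from code paths that may already hold it. A minimal standalone illustration of the pattern, with simplified names that are not from the commit:

    #include <mutex>

    struct device_t {
        std::recursive_mutex mutex;   // re-entrant: the same thread may lock it again
    };

    static void host_get(device_t & dev) {
        std::lock_guard<std::recursive_mutex> guard(dev.mutex);  // nested acquisition: OK
        // ... look up pinned memory ...
    }

    static void buffer_write(device_t & dev) {
        std::lock_guard<std::recursive_mutex> guard(dev.mutex);  // outer acquisition
        host_get(dev);  // with a plain std::mutex this nested lock would deadlock
    }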
@@ -1213,7 +1213,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
}
{
std::lock_guard<std::mutex> guard(device->mutex);
std::lock_guard<std::recursive_mutex> guard(device->mutex);
device->pipelines.insert({ pipeline->name, pipeline });
}
@@ -1427,7 +1427,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
VK_LOG_DEBUG("ggml_vk_create_queue()");
std::lock_guard<std::mutex> guard(device->mutex);
std::lock_guard<std::recursive_mutex> guard(device->mutex);
q.queue_family_index = queue_family_index;
q.transfer_only = transfer_only;
@@ -4148,6 +4148,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
return nullptr;
}
std::lock_guard<std::recursive_mutex> guard(device->mutex);
device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
return buf->ptr;
@@ -4158,6 +4159,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
return;
}
VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
std::lock_guard<std::recursive_mutex> guard(device->mutex);
vk_buffer buf;
size_t index;
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -4180,6 +4183,7 @@
}
static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
std::lock_guard<std::recursive_mutex> guard(device->mutex);
buf = nullptr;
buf_offset = 0;
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -4481,7 +4485,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
}
} else {
std::lock_guard<std::mutex> guard(dst->device->mutex);
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dst->device, subctx);
@@ -4572,7 +4576,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
memcpy(dst, (uint8_t *) src->ptr + offset, size);
} else {
std::lock_guard<std::mutex> guard(src->device->mutex);
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(src->device, subctx);
@@ -4602,7 +4606,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
if (src->device == dst->device) {
std::lock_guard<std::mutex> guard(src->device->mutex);
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
// Copy within the device
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
@@ -4637,7 +4641,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
std::lock_guard<std::mutex> guard(dst->device->mutex);
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dst->device, subctx);
subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
@@ -4864,9 +4868,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
// type size must be exactly 2 or 4.
GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
if ((ggml_type_size(src->type) % 4) == 0) {
return ctx->device->pipeline_contig_cpy_f32_f32;
if (contig) {
return ctx->device->pipeline_contig_cpy_f32_f32;
} else {
return ctx->device->pipeline_cpy_f32_f32;
}
} else {
return ctx->device->pipeline_contig_cpy_f16_f16;
if (contig) {
return ctx->device->pipeline_contig_cpy_f16_f16;
} else {
return ctx->device->pipeline_cpy_f16_f16;
}
}
}
@@ -4927,7 +4939,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
const uint64_t ne00 = src0->ne[0];
@@ -5155,7 +5167,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
const uint64_t ne00 = src0->ne[0];
@@ -5756,7 +5768,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
GGML_ASSERT(ids->type == GGML_TYPE_I32);