Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	CODEOWNERS
#	docs/build.md
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp
#	tools/imatrix/README.md
#	tools/imatrix/imatrix.cpp
Concedo 2025-07-20 22:47:31 +08:00
commit 30675b0798
21 changed files with 1397 additions and 1181 deletions

@ -344,6 +344,7 @@ struct vk_device_struct {
uint64_t max_memory_allocation_size;
uint64_t suballocation_block_size;
bool fp16;
bool bf16;
bool pipeline_robustness;
vk::Device device;
uint32_t vendor_id;
@ -498,6 +499,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_conv2d_f32;
vk_pipeline pipeline_conv2d_dw_whcn_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
@ -891,6 +893,38 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t H;
};
struct vk_op_conv2d_push_constants {
uint32_t Cout;
uint32_t Cin;
uint32_t N;
uint32_t KW;
uint32_t KH;
uint32_t W;
uint32_t H;
uint32_t OW;
uint32_t OH;
uint32_t s0;
uint32_t s1;
uint32_t p0;
uint32_t p1;
uint32_t d0;
uint32_t d1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
};
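These push constants describe the direct convolution as an implicit GEMM: the kernel is treated as a Cout x (Cin*KW*KH) matrix and the implicitly im2col'd input as a (Cin*KW*KH) x (N*OW*OH) matrix. A minimal sketch of that mapping (the helper and its names are illustrative, not part of the backend):

```cpp
#include <cstdint>

// Sketch only: the implicit-GEMM dimensions that the conv2d push constants imply.
struct conv2d_gemm_dims { uint32_t M, K, N; };

static conv2d_gemm_dims conv2d_as_gemm(uint32_t Cout, uint32_t Cin, uint32_t N_batch,
                                        uint32_t KW, uint32_t KH, uint32_t OW, uint32_t OH) {
    return { Cout,                // M: one GEMM row per output channel
             Cin * KW * KH,       // K: one reduction element per (input channel, kernel tap)
             N_batch * OW * OH }; // N: one GEMM column per output pixel per image
}
```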
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
@ -990,18 +1024,45 @@ private:
#endif // GGML_VULKAN_MEMORY_DEBUG
class vk_perf_logger {
public:
void print_timings() {
if (timings.empty()) {
return;
}
uint64_t total_all_op_times = 0;
std::cerr << "----------------\nVulkan Timings:" << std::endl;
for (const auto & t : timings) {
uint64_t total_op_times = 0;
for (const auto & time : t.second) {
total_op_times += time;
}
std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
<< " us";
// If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
auto it = flops.find(t.first);
if (it != flops.end() && (it->second).size() == t.second.size()) {
uint64_t total_op_flops = 0;
for (const auto & elem : it->second) {
total_op_flops += elem;
}
std::cerr << " ("
<< (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
(double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
<< " GFLOPS/s)";
}
total_all_op_times += total_op_times;
std::cerr << std::endl;
}
if (timings.size() > 0) {
std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
}
timings.clear();
flops.clear();
}
void log_timing(const ggml_tensor * node, uint64_t time) {
@ -1010,22 +1071,45 @@ public:
return;
}
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->src[0]->ne[1];
const uint64_t n = node->src[1]->ne[1];
const uint64_t k = node->src[1]->ne[0];
std::string name = ggml_op_name(node->op);
if (n == 1) {
name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
} else {
name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
}
timings[name].push_back(time);
flops[name].push_back(m * n * (k + (k - 1)));
return;
}
if (node->op == GGML_OP_CONV_2D) {
std::string name = ggml_op_name(node->op);
ggml_tensor * knl = node->src[0];
uint64_t OW = node->ne[0];
uint64_t OH = node->ne[1];
uint64_t N = node->ne[3];
uint64_t Cout = node->ne[2];
uint64_t KW = knl->ne[0];
uint64_t KH = knl->ne[1];
uint64_t Cin = knl->ne[2];
// KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
uint64_t size_M = Cout;
uint64_t size_K = Cin * KW * KH;
uint64_t size_N = N * OW * OH;
uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
", N=N*OW*OH=" + std::to_string(size_N);
flops[name].push_back(n_flops);
timings[name].push_back(time);
return;
}
timings[ggml_op_name(node->op)].push_back(time);
}
private:
std::map<std::string, std::vector<uint64_t>> timings;
std::map<std::string, std::vector<uint64_t>> flops;
};
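The `K + (K - 1)` factor counts K multiplies and K - 1 additions per output element, and timings are stored in nanoseconds, so flops divided by nanoseconds is numerically GFLOPS/s. A standalone sketch of the same estimate for a hypothetical 3x3 convolution and an invented kernel time, just to show the units the logger prints:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical shapes: N=1, Cin=64, Cout=128, 3x3 kernel, 56x56 output.
    const uint64_t Cout = 128, Cin = 64, KW = 3, KH = 3, N = 1, OW = 56, OH = 56;
    const uint64_t M   = Cout;           // GEMM rows
    const uint64_t K   = Cin * KW * KH;  // reduction length
    const uint64_t NPQ = N * OW * OH;    // GEMM columns
    const uint64_t n_flops = M * NPQ * (K + (K - 1)); // same formula as log_timing()

    const double time_ns = 1.5e6; // invented 1.5 ms kernel time
    const double gflops  = (double(n_flops) / 1.0e9) / (time_ns / 1.0e9); // == flops / ns
    std::printf("%llu flops, %.1f GFLOPS/s\n", (unsigned long long) n_flops, gflops);
    return 0;
}
```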
struct ggml_backend_vk_context {
@ -2128,6 +2212,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
compile_count++;
}
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
};
@ -2977,6 +3062,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d
uint32_t conv2d_WG_SIZE = 256;
uint32_t conv2d_BS_K = 128;
uint32_t conv2d_BS_CRS = 16;
uint32_t use_collectives = 0; // Enables subgroup ops to avoid recalculating indices.
if (device->subgroup_shuffle &&
device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
use_collectives = 1;
conv2d_BS_CRS = std::min(
device->subgroup_size,
conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
}
uint32_t conv2d_BS_NPQ = 128;
uint32_t conv2d_TS_K = 8;
uint32_t conv2d_shmem_req =
(conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
conv2d_BS_CRS = 8;
if (use_collectives) {
conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
}
}
if (use_collectives) {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
} else {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
false);
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
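For reference, the shared-memory fallback triggers because the default tiles need (128*(16+1) + 16*(128+1)) * 4 = 16960 bytes, just over the 16 KiB that some devices report for maxComputeSharedMemorySize. A standalone sketch of that arithmetic, with an invented device limit and subgroup size:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t BS_K = 128, BS_CRS = 16, BS_NPQ = 128;
    uint32_t shmem = (BS_K * (BS_CRS + 1) + BS_CRS * (BS_NPQ + 1)) * sizeof(float);
    std::printf("default tiles need %u bytes of shared memory\n", shmem); // 16960

    const uint32_t max_shared    = 16384; // invented device limit
    const uint32_t subgroup_size = 32;    // invented subgroup size
    if (max_shared < shmem) {
        // Same fallback as above: shrink the CRS tile, re-applying the subgroup-size cap.
        BS_CRS = std::min(subgroup_size, 8u);
        shmem  = (BS_K * (BS_CRS + 1) + BS_CRS * (BS_NPQ + 1)) * sizeof(float);
    }
    std::printf("fallback tiles need %u bytes\n", shmem); // 8736 with BS_CRS = 8
    return 0;
}
```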
@ -3297,6 +3418,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
#if defined(VK_KHR_shader_bfloat16)
device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
#else
device->bf16 = false;
#endif
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
if (device->subgroup_size_control) {
@ -3639,6 +3766,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
bool coopmat_support = false;
bool coopmat2_support = false;
bool integer_dot_product = false;
bool bfloat16_support = false;
for (auto properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@ -3659,6 +3787,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
integer_dot_product = true;
#endif
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
}
}
@ -3725,10 +3858,25 @@ static void ggml_vk_print_gpu_info(size_t idx) {
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
}
#if defined(VK_KHR_shader_bfloat16)
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
if (bfloat16_support) {
last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
last_struct = (VkBaseOutStructure *)&bfloat16_features;
}
#endif
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
fp16 = fp16 && vk12_features.shaderFloat16;
#if defined(VK_KHR_shader_bfloat16)
bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
#else
bool bf16 = false;
#endif
uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
@ -3746,8 +3894,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
@ -6833,6 +6981,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_leaky_relu_f32;
}
return nullptr;
case GGML_OP_CONV_2D:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
return ctx->device->pipeline_conv2d_f32;
}
return nullptr;
case GGML_OP_CONV_2D_DW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (ggml_is_contiguous(src1)) {
@ -7155,6 +7309,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t OW = dst->ne[0];
elements = { N * OC * OH * OW, 1, 1};
} break;
case GGML_OP_CONV_2D:
{
// src0 - kernel: [KW, KH, Cin, Cout]
// src1 - input: [W, H, Cin, N]
// dst - result: [OW, OH, Cout, N]
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
};
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
int64_t W = src1->ne[0];
int64_t H = src1->ne[1];
int64_t KW = src0->ne[0];
int64_t KH = src0->ne[1];
int64_t Cout = src0->ne[3];
int64_t N = src1->ne[3];
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
int64_t NPQ = N * OW * OH;
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
}
break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
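As a quick sanity check of the calc_conv_output_size lambda above, here is a standalone sketch with a few invented shapes (stride/padding/dilation chosen to show the common cases):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
    };
    assert(calc_conv_output_size(224, 3, 1, 1, 1) == 224); // "same" convolution keeps the size
    assert(calc_conv_output_size(224, 3, 2, 1, 1) == 112); // stride 2 halves it
    assert(calc_conv_output_size(224, 3, 1, 0, 1) == 222); // no padding trims the border
    return 0;
}
```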
@ -8021,6 +8200,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
}, dryrun);
}
static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
vk_op_conv2d_push_constants p{};
p.Cout = static_cast<uint32_t>(ne03);
p.Cin = static_cast<uint32_t>(ne02);
p.N = static_cast<uint32_t>(ne13);
p.KW = static_cast<uint32_t>(ne00);
p.KH = static_cast<uint32_t>(ne01);
p.W = static_cast<uint32_t>(ne10);
p.H = static_cast<uint32_t>(ne11);
p.OW = static_cast<uint32_t>(ne0);
p.OH = static_cast<uint32_t>(ne1);
p.s0 = static_cast<uint32_t>(dst->op_params[0]);
p.s1 = static_cast<uint32_t>(dst->op_params[1]);
p.p0 = static_cast<uint32_t>(dst->op_params[2]);
p.p1 = static_cast<uint32_t>(dst->op_params[3]);
p.d0 = static_cast<uint32_t>(dst->op_params[4]);
p.d1 = static_cast<uint32_t>(dst->op_params[5]);
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
GGML_ASSERT(ne03 == ne2);
GGML_ASSERT(ne02 == ne12);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
}
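Note that the nb* push constants above are element strides: each byte stride from GGML_TENSOR_BINARY_OP_LOCALS is divided by the element size of its tensor (nb00, nb10, and nb0 are all asserted to be sizeof(float)). A standalone check of that conversion for a hypothetical contiguous f32 kernel of shape [KW=3, KH=3, Cin=64, Cout=128]:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // Hypothetical contiguous f32 kernel with ne = {3, 3, 64, 128} and byte strides nb.
    const uint64_t ne00 = 3, ne01 = 3, ne02 = 64;
    const uint64_t nb00 = sizeof(float);
    const uint64_t nb01 = nb00 * ne00;  // 12 bytes between kernel rows
    const uint64_t nb02 = nb01 * ne01;  // 36 bytes between input channels
    const uint64_t nb03 = nb02 * ne02;  // 2304 bytes between output-channel filters

    // Element strides as packed into vk_op_conv2d_push_constants.
    assert(nb01 / nb00 == 3);
    assert(nb02 / nb00 == 9);
    assert(nb03 / nb00 == 576); // == Cin*KW*KH, the GEMM reduction length
    return 0;
}
```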
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
@ -9083,6 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
@ -9150,6 +9379,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_LEAKY_RELU:
{
@ -9356,6 +9586,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_POOL_2D:
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_CONV_2D:
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@ -9486,6 +9720,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
@ -10067,6 +10302,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
} else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
// Count CRS*NPQ*sizeof(element) bytes so CONV_2D is weighted the same as the equivalent im2col->mul_mat.
auto CRS_size =
cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
}
i += ctx->num_additional_fused_ops;
ctx->num_additional_fused_ops = 0;
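This weighting counts a direct CONV_2D like the mul_mat it replaces: the implicit GEMM reads a CRS x NPQ im2col matrix, so CRS * NPQ * sizeof(element) bytes are added to total_mat_mul_bytes. A worked number for a hypothetical node (illustrative shapes only):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical CONV_2D node: kernel ne = {3, 3, 64, 128}, dst ne = {224, 224, 128, 1}, f32.
    const uint64_t CRS = 3 * 3 * 64;    // src[0]->ne[0] * ne[1] * ne[2]
    const uint64_t NPQ = 224 * 224 * 1; // dst->ne[0] * ne[1] * ne[3]
    const uint64_t bytes = NPQ * CRS * sizeof(float);
    std::printf("counted as %llu mat-mul bytes (~%.0f MiB)\n",
                (unsigned long long) bytes, bytes / (1024.0 * 1024.0));
    return 0;
}
```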
@ -10643,6 +10884,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
{
// Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
// Channel-contiguous format is not supported yet.
return (op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op)) && !is_Apple;
}
default:
return false;
}
@ -11201,6 +11456,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t p1 = tensor->op_params[6];
tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
} else if (tensor->op == GGML_OP_CONV_2D) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
const int32_t p0 = tensor->op_params[2];
const int32_t p1 = tensor->op_params[3];
const int32_t d0 = tensor->op_params[4];
const int32_t d1 = tensor->op_params[5];
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
const float * op_params = (const float *)tensor->op_params;
tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);