ggml: add conv3d op (#15182)

* add conv3d

* bump GGML_OP_COUNT
This commit is contained in:
rmatif 2025-08-22 15:33:15 +02:00 committed by GitHub
parent b1ab91821f
commit 92f7f0a53c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 345 additions and 2 deletions

View file

@ -512,6 +512,7 @@ extern "C" {
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
GGML_OP_CONV_2D,
GGML_OP_CONV_3D,
GGML_OP_CONV_2D_DW,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
@ -1940,6 +1941,23 @@ extern "C" {
int d0, // dilation dimension 0
int d1); // dilation dimension 1
GGML_API struct ggml_tensor * ggml_conv_3d(
struct ggml_context * ctx,
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
struct ggml_tensor * b, // input [W, H, D, C * N]
int s0, // stride
int s1,
int s2,
int p0, // padding
int p1,
int p2,
int d0, // dilation
int d1,
int d2,
int n_channels,
int n_batch,
int n_channels_out);
enum ggml_op_pool {
GGML_OP_POOL_MAX,
GGML_OP_POOL_AVG,

View file

@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_conv_2d(params, tensor);
} break;
case GGML_OP_CONV_3D:
{
ggml_compute_forward_conv_3d(params, tensor);
} break;
case GGML_OP_CONV_2D_DW:
{
ggml_compute_forward_conv_2d_dw(params, tensor);
@ -2252,6 +2256,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_BACK:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_CONV_TRANSPOSE_2D:
@ -2773,6 +2778,7 @@ struct ggml_cplan ggml_graph_plan(
}
} break;
case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D:
{
cur = GGML_IM2COL_WORK_SIZE;
} break;

View file

@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
}
// ggml_compute_forward_conv_3d
static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
const ggml_tensor * kernel,
const ggml_tensor * src,
ggml_tensor * dst,
ggml_type kernel_type) {
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
GGML_ASSERT(kernel->type == kernel_type);
const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
const int32_t s0 = dst->op_params[0];
const int32_t s1 = dst->op_params[1];
const int32_t s2 = dst->op_params[2];
const int32_t p0 = dst->op_params[3];
const int32_t p1 = dst->op_params[4];
const int32_t p2 = dst->op_params[5];
const int32_t d0 = dst->op_params[6];
const int32_t d1 = dst->op_params[7];
const int32_t d2 = dst->op_params[8];
const int32_t c = dst->op_params[9];
const int32_t n = dst->op_params[10];
const int32_t oc = dst->op_params[11];
const int64_t src_w = src->ne[0];
const int64_t src_h = src->ne[1];
const int64_t src_d = src->ne[2];
const int64_t knl_w = kernel->ne[0];
const int64_t knl_h = kernel->ne[1];
const int64_t knl_d = kernel->ne[2];
const int64_t dst_w = dst->ne[0];
const int64_t dst_h = dst->ne[1];
const int64_t dst_d = dst->ne[2];
const float * src_data = (float *) src->data;
void * knl_data = kernel->data;
float * dst_data = (float *) dst->data;
const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
const int64_t knl_n_total = knl_n_per_channel * c;
const int64_t patch_total = n * dst_w * dst_h * dst_d;
const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
const int64_t batch_size = params->wsize / space_per_patch;
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
void * tmp = params->wdata;
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
const int64_t patch_start_batch = batch_i * patches_per_batch;
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
for (int64_t p = patch_start; p < patch_end; ++p) {
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
const int64_t dst_y = p_in_depth / dst_w;
const int64_t dst_x = p_in_depth % dst_w;
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
for (int64_t ic = 0; ic < c; ++ic) {
for (int64_t kz = 0; kz < knl_d; ++kz) {
for (int64_t ky = 0; ky < knl_h; ++ky) {
for (int64_t kx = 0; kx < knl_w; ++kx) {
const int64_t sz = dst_z * s2 + kz * d2 - p2;
const int64_t sy = dst_y * s1 + ky * d1 - p1;
const int64_t sx = dst_x * s0 + kx * d0 - p0;
int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
float src_val;
if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
src_val = 0.0f;
} else {
const int64_t cn_idx = batch_idx * c + ic;
const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
src_val = *src_ptr;
}
char * element_ptr = dst_row + dst_idx * traits->type_size;
if (kernel_type == GGML_TYPE_F32) {
*(float *)element_ptr = src_val;
} else if (kernel_type == GGML_TYPE_F16) {
*(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
}
}
}
}
}
}
ggml_barrier(params->threadpool);
float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
ggml_barrier(params->threadpool);
const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
const int64_t permute_start = params->ith * permute_per_thread;
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
for (int64_t i = permute_start; i < permute_end; ++i) {
const int64_t p = patch_start_batch + i;
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
const int64_t dst_y = p_in_depth / dst_w;
const int64_t dst_x = p_in_depth % dst_w;
for (int64_t ioc = 0; ioc < oc; ++ioc) {
const float value = gemm_output[i * oc + ioc];
const int64_t ocn_idx = batch_idx * oc + ioc;
float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
*dst_ptr = value;
}
}
}
}
void ggml_compute_forward_conv_3d(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
}
// ggml_compute_forward_conv_transpose_2d
void ggml_compute_forward_conv_transpose_2d(

View file

@ -70,6 +70,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View file

@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"IM2COL",
"IM2COL_BACK",
"CONV_2D",
"CONV_3D",
"CONV_2D_DW",
"CONV_TRANSPOSE_2D",
"POOL_1D",
@ -1017,7 +1018,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1077,6 +1078,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"im2col(x)",
"im2col_back(x)",
"conv_2d(x)",
"conv_3d(x)",
"conv_2d_dw(x)",
"conv_transpose_2d(x)",
"pool_1d(x)",
@ -1119,7 +1121,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -4480,6 +4482,56 @@ struct ggml_tensor * ggml_conv_2d_direct(
return result;
}
// ggml_conv_3d
struct ggml_tensor * ggml_conv_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int s2,
int p0,
int p1,
int p2,
int d0,
int d1,
int d2,
int c,
int n,
int oc) {
GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
GGML_ASSERT(b->ne[3] == (int64_t) c * n);
int64_t ne[4];
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
ne[3] = (int64_t) oc * n;
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
ggml_set_op_params_i32(result, 0, s0);
ggml_set_op_params_i32(result, 1, s1);
ggml_set_op_params_i32(result, 2, s2);
ggml_set_op_params_i32(result, 3, p0);
ggml_set_op_params_i32(result, 4, p1);
ggml_set_op_params_i32(result, 5, p2);
ggml_set_op_params_i32(result, 6, d0);
ggml_set_op_params_i32(result, 7, d1);
ggml_set_op_params_i32(result, 8, d2);
ggml_set_op_params_i32(result, 9, c);
ggml_set_op_params_i32(result, 10, n);
ggml_set_op_params_i32(result, 11, oc);
result->op = GGML_OP_CONV_3D;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_conv_transpose_2d_p0
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {

View file

@ -4091,6 +4091,75 @@ struct test_conv_2d_dw : public test_case {
}
};
// GGML_OP_CONV_3D
struct test_conv_3d : public test_case {
// Logical 5D dimensions
const int64_t N, IC, ID, IH, IW;
const int64_t OC, KD, KH, KW;
// Conv params
const int s0, s1, s2;
const int p0, p1, p2;
const int d0, d1, d2;
// Types
const ggml_type type_kernel;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "CONV_3D";
}
std::string vars() override {
return VARS_TO_STR11(N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1) + "," +
VARS_TO_STR8(s2, p0, p1, p2, d0, d1, d2, type_kernel);
}
double max_nmse_err() override {
return 5e-4;
}
uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
};
const int64_t OD = calc_conv_output_size(ID, KD, s2, p2, d2);
const int64_t OH = calc_conv_output_size(IH, KH, s1, p1, d1);
const int64_t OW = calc_conv_output_size(IW, KW, s0, p0, d0);
return (uint64_t)N * OC * OD * OH * OW * (2 * IC * KD * KH * KW - 1);
}
test_conv_3d(
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW,
int64_t OC, int64_t KD, int64_t KH, int64_t KW,
int s0, int s1, int s2,
int p0, int p1, int p2,
int d0, int d1, int d2,
ggml_type type_kernel
) : N(N), IC(IC), ID(ID), IH(IH), IW(IW),
OC(OC), KD(KD), KH(KH), KW(KW),
s0(s0), s1(s1), s2(s2),
p0(p0), p1(p1), p2(p2),
d0(d0), d1(d1), d2(d2),
type_kernel(type_kernel) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
// GGML input tensor is packed as [W, H, D, C*N]
const int64_t ne_input[] = {IW, IH, ID, IC * N};
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input);
ggml_set_name(input, "input");
// GGML kernel tensor is packed as [KW, KH, KD, IC*OC]
const int64_t ne_kernel[] = {KW, KH, KD, IC * OC};
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel);
ggml_set_name(kernel, "kernel");
ggml_tensor * out = ggml_conv_3d(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_CONCAT
struct test_concat : public test_case {
const ggml_type type;
@ -5528,6 +5597,61 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
// CONV_3D
auto calc_conv_output_size_3d = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
};
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (int N : {1, 2}) {
for (int IC : {1, 3}) {
for (int OC : {1, 4}) {
for (int s0 : {1, 2}) {
for (int p1 : {0, 1}) {
for (int d2 : {1, 2}) {
int64_t IW = 20, IH = 22, ID = 18;
int64_t KW = 3, KH = 3, KD = 3;
int s1 = s0, s2 = s0;
int p0 = p1, p2 = p1;
int d0 = d2, d1 = d2;
if (calc_conv_output_size_3d(IW, KW, s0, p0, d0) <= 0 ||
calc_conv_output_size_3d(IH, KH, s1, p1, d1) <= 0 ||
calc_conv_output_size_3d(ID, KD, s2, p2, d2) <= 0) {
continue;
}
test_cases.emplace_back(new test_conv_3d(
N, IC, ID, IH, IW,
OC, KD, KH, KW,
s0, s1, s2, p0, p1, p2, d0, d1, d2,
kernel_type));
// Asymmetric kernel and params
int64_t asym_KW = 5, asym_KH = 1, asym_KD = 3;
int asym_s0 = 2, asym_s1 = 1, asym_s2 = 1;
int asym_p0 = 2, asym_p1 = 0, asym_p2 = 1;
int asym_d0 = 1, asym_d1 = 1, asym_d2 = 2;
if (calc_conv_output_size_3d(IW, asym_KW, asym_s0, asym_p0, asym_d0) <= 0 ||
calc_conv_output_size_3d(IH, asym_KH, asym_s1, asym_p1, asym_d1) <= 0 ||
calc_conv_output_size_3d(ID, asym_KD, asym_s2, asym_p2, asym_d2) <= 0) {
continue;
}
test_cases.emplace_back(new test_conv_3d(
N, IC, ID, IH, IW,
OC, asym_KD, asym_KH, asym_KW,
asym_s0, asym_s1, asym_s2, asym_p0, asym_p1, asym_p2, asym_d0, asym_d1, asym_d2,
kernel_type));
}
}
}
}
}
}
// Case with kernel size 1
test_cases.emplace_back(new test_conv_3d(1, 4, 8, 8, 8, 8, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, kernel_type));
}
for(uint32_t Cout : {1, 9}){
for(uint32_t Cin : {1, 7}){
for(uint32_t K : {1, 3, 1337}){