Merge remote-tracking branch 'jeff/im2col_wglimit' into concedo_experimental

# Conflicts:
#	tests/test-backend-ops.cpp
This commit is contained in:
Concedo 2025-12-19 11:01:47 +08:00
commit fef2ea46fd
2 changed files with 26 additions and 6 deletions

View file

@ -1274,6 +1274,7 @@ struct vk_op_im2col_push_constants {
int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
int32_t d0; int32_t d1;
uint32_t batch_IC;
};
struct vk_op_im2col_3d_push_constants {
@ -5900,6 +5901,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
}
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
@ -9086,6 +9090,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
elements = { OW * KW * KH, OH, batch * IC };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break;
case GGML_OP_IM2COL_3D:
{
@ -10589,6 +10595,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const uint32_t pelements = OW * KW * KH;
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@ -10601,7 +10608,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
IC, IW, IH, OW, OH, KW, KH,
pelements,
IC * KH * KW,
s0, s1, p0, p1, d0, d1,
s0, s1, p0, p1, d0, d1, batch * IC
});
}
@ -14715,7 +14722,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_LOG) {
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_TRI) {
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
} else if (tensor->op == GGML_OP_DIAG) {
tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) {

View file

@ -19,6 +19,7 @@ layout (push_constant) uniform parameter
int s0; int s1;
int p0; int p1;
int d0; int d1;
uint batch_IC;
} p;
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
void im2col(const uint y, const uint z) {
const uint gidx = gl_GlobalInvocationID.x;
const uint oh = gl_GlobalInvocationID.y;
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
const uint oh = y;
const uint batch = z / p.IC;
const uint ic = z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
@ -101,3 +102,15 @@ void main() {
#endif
}
}
void main() {
uint y = gl_GlobalInvocationID.y;
while (y < p.OH) {
uint z = gl_GlobalInvocationID.z;
while (z < p.batch_IC) {
im2col(y, z);
z += gl_NumWorkGroups.z;
}
y += gl_NumWorkGroups.y;
}
}