mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-22 11:16:08 +00:00
vulkan: optimize operations in the IM2COL shader (#22685)
* vulkan: optimize operations in the IM2COL shader * Add comments and improve the code formatting
This commit is contained in:
parent
6ce96713de
commit
acd604fb27
1 changed files with 59 additions and 14 deletions
|
|
@ -44,36 +44,81 @@ void im2col(const uint ow, const uint z_idx) {
|
|||
|
||||
const uint KHKW = p.KH * p.KW;
|
||||
|
||||
// Precompute base input coordinates
|
||||
const int base_iw = int(ow * p.s0) - p.p0;
|
||||
const int base_ih = int(oh * p.s1) - p.p1;
|
||||
|
||||
// Precompute step deltas
|
||||
const uint delta_ic = BLOCK_SIZE / KHKW;
|
||||
const uint delta_rem = BLOCK_SIZE % KHKW;
|
||||
|
||||
const uint delta_ky = delta_rem / p.KW;
|
||||
const uint delta_kx = delta_rem % p.KW;
|
||||
|
||||
const uint delta_ic_offset = delta_ic * p.offset_delta;
|
||||
|
||||
// If using BDA mode, precompute the base pointer and step size
|
||||
#if BDA
|
||||
const BDA_STORAGE_T base_dst_addr = p.dst_addr + D_SIZE * dst_row;
|
||||
const uint bda_step = D_SIZE * BLOCK_SIZE;
|
||||
#endif
|
||||
|
||||
uint wg_x = gl_WorkGroupID.x;
|
||||
do {
|
||||
const uint wg_offset = wg_x * 512;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < NUM_ITER; ++i) {
|
||||
const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE;
|
||||
uint chw_idx = wg_offset + gidx;
|
||||
|
||||
uint ic = chw_idx / KHKW;
|
||||
uint rem = chw_idx % KHKW;
|
||||
|
||||
uint ky = rem / p.KW;
|
||||
uint kx = rem % p.KW;
|
||||
|
||||
uint ic_offset = src_batch + ic * p.offset_delta;
|
||||
|
||||
// Initialize running pointer/index for the destination buffer
|
||||
#if BDA
|
||||
BDA_STORAGE_T current_dst_addr = base_dst_addr + D_SIZE * chw_idx;
|
||||
#else
|
||||
uint current_dst_idx = dst_row + chw_idx;
|
||||
#endif
|
||||
|
||||
[[unroll]] for (uint i = 0; i < NUM_ITER; ++i) {
|
||||
if (chw_idx >= p.CHW) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint ic = chw_idx / KHKW;
|
||||
const uint rem = chw_idx - ic * KHKW;
|
||||
const uint ky = rem / p.KW;
|
||||
const uint kx = rem - ky * p.KW;
|
||||
|
||||
const uint iiw = ow * p.s0 + kx * p.d0 - p.p0;
|
||||
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
|
||||
const int iiw = base_iw + int(kx * p.d0);
|
||||
const int iih = base_ih + int(ky * p.d1);
|
||||
|
||||
A_TYPE val = A_TYPE(0);
|
||||
if (iih < p.IH && iiw < p.IW) {
|
||||
val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw];
|
||||
if (uint(iih) < p.IH && uint(iiw) < p.IW) {
|
||||
val = data_a[ic_offset + uint(iih) * p.IW + uint(iiw)];
|
||||
}
|
||||
|
||||
#if BDA
|
||||
D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx));
|
||||
out_ptr.d = D_TYPE(val);
|
||||
D_ptr(current_dst_addr).d = D_TYPE(val);
|
||||
current_dst_addr += bda_step;
|
||||
#else
|
||||
data_d[dst_row + chw_idx] = D_TYPE(val);
|
||||
data_d[current_dst_idx] = D_TYPE(val);
|
||||
current_dst_idx += BLOCK_SIZE;
|
||||
#endif
|
||||
|
||||
chw_idx += BLOCK_SIZE;
|
||||
ic_offset += delta_ic_offset;
|
||||
kx += delta_kx;
|
||||
ky += delta_ky;
|
||||
|
||||
// Handle X axis wrap
|
||||
uint kx_wrap = uint(kx >= p.KW);
|
||||
kx -= kx_wrap * p.KW;
|
||||
ky += kx_wrap;
|
||||
|
||||
// Handle Y axis wrap
|
||||
uint ky_wrap = uint(ky >= p.KH);
|
||||
ky -= ky_wrap * p.KH;
|
||||
ic_offset += ky_wrap * p.offset_delta;
|
||||
}
|
||||
|
||||
wg_x += gl_NumWorkGroups.x;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue