vulkan: remove PV matrix, helps with register usage

This commit is contained in:
Jeff Bolz 2025-05-07 13:46:35 -05:00
parent 876e6617a7
commit a6c940bb79

View file

@@ -276,27 +276,22 @@ void main() {
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
}
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
vec4 PVf[Br][D_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
PVf[r][d] = vec4(0.0);
Of[r][d] = eMf[r] * Of[r][d];
}
}
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
PVf[r][d] += Pf[r][c] * Vf;
Of[r][d] += Pf[r][c] * Vf;
}
}
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d] + PVf[r][d];
}
}
barrier();
}