metal : optimize pad + cpy (#23354)

* metal : optimize pad

* metal : optinmize cpy

* cont : better row packing in threadgroup
This commit is contained in:
Georgi Gerganov 2026-05-20 09:42:00 +03:00 committed by GitHub
parent 871b0b70f8
commit 57ebaf4edd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 94 additions and 67 deletions

View file

@ -562,13 +562,13 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
}
const int64_t D = S_v * S_v * H_v;
const int64_t K = (int64_t) cparams.n_rs_seq + 1;
const int64_t K = cparams.n_rs_seq + 1;
// TODO: remove pad + simplify
ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0);
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0);
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d);
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad);
if (n_seq_tokens > 1) {
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
} else {