diff --git a/src/llama.cpp b/src/llama.cpp index 1020277c..d316e400 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3052,8 +3052,7 @@ struct llama_sbatch { ubatch_token.resize(!has_embd ? n_ubatch : 0); ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); - // TODO: just a guess and test, need to be removed(from tao) - ubatch_backend_embd.resize(n_embd * n_tokens * 3); + ubatch_backend_embd.resize(n_embd * n_tokens + n_tokens); ubatch_out_embd.resize(n_embd * n_tokens); ubatch_pos.resize(n_ubatch);