diff --git a/src/llama.cpp b/src/llama.cpp
index b8da9822..5be02d68 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18066,16 +18066,17 @@ static int llama_decode_internal(
         }
 
         // overlap memory scheduling with other nodes' communication and computing
-        if (cparams.unload) {
+        {
             timer(manage_graph_tensors);
-            if (n_world != 1) {
-                manage_graph_tensors(sub_gf, POSIX_MADV_DONTNEED);
+
+            int next_gf_id = (i + 1) % gf.size();
+            manage_graph_tensors(gf[next_gf_id], POSIX_MADV_WILLNEED, false);
+            if (my_rank == 0 && (is_last_l || (next_gf_id == (int)gf.size() - 1))) {
+                manage_graph_tensors(gf[0], POSIX_MADV_WILLNEED, false);
+            }
 
-                int next_gf_id = (i + 1) % gf.size();
-                manage_graph_tensors(gf[next_gf_id], POSIX_MADV_WILLNEED, false);
-                if (my_rank == 0 && (is_last_l || (next_gf_id == (int)gf.size() - 1))) {
-                    manage_graph_tensors(gf[0], POSIX_MADV_WILLNEED, false);
-                }
+            if (cparams.unload && n_world > 1) {
+                manage_graph_tensors(sub_gf, POSIX_MADV_DONTNEED);
             }
         }
     }