From c4c6a642fc7feffdcc956624de4b91ff52a10276 Mon Sep 17 00:00:00 2001
From: Zonghang Li
Date: Sat, 8 Feb 2025 22:44:38 +0400
Subject: [PATCH] manage_graph_tensors: fix segment prefetch

---
 src/llama.cpp | 56 +++++++++++++++++++++++++++++++++++----------------
 1 file changed, 39 insertions(+), 17 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index e770aca8..30226968 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17766,35 +17766,57 @@ static float is_graph_loaded(struct ggml_cgraph * cgraph) {
 }
 
 static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool force = false) {
-    size_t first = SIZE_MAX;
-    size_t last = 0;
     long page_size = sysconf(_SC_PAGESIZE);
 
+    struct Segment {
+        size_t start;
+        size_t end;
+    };
+    std::vector<Segment> segments;
+
     for (int i = 0; i < ggml_graph_n_leafs(cgraph); i++) {
         struct ggml_tensor * cur = ggml_graph_leaf(cgraph, i);
-
         if (strstr(cur->name, "weight") == nullptr || cur->data == nullptr) {
             continue;
         }
 
-        size_t addr = reinterpret_cast<size_t>(cur->data);
-        first = std::min(first, addr);
-        last = std::max(last, addr + ggml_nbytes(cur));
+        size_t size = ggml_nbytes(cur);
+        size_t first = reinterpret_cast<size_t>(cur->data);
+        size_t last = first + size;
+
+        first = first - (first % page_size);
+        if (last % page_size != 0) {
+            last = last + (page_size - (last % page_size));
+        }
+        segments.push_back({first, last});
     }
 
-    // align addr
-    llama_mmap::align_range(&first, &last, page_size);
-    size_t len = std::max(last - first, static_cast<size_t>(page_size));
+    if (segments.empty()) return;
 
-    // hint to load memory
-    posix_madvise(reinterpret_cast<void *>(first), len, advice);
+    std::sort(segments.begin(), segments.end(), [](const Segment & a, const Segment & b) {
+        return a.start < b.start;
+    });
 
-    // if advice is POSIX_MADV_WILLNEED, force to prefetch data
-    if (force && advice == POSIX_MADV_WILLNEED) {
-        // coarse-grained prefetch
-        volatile char * ptr = (volatile char *)first;
-        for (size_t off = 0; off < len; off += page_size) {
-            (void)ptr[off];
+    std::vector<Segment> merged_segments;
+    merged_segments.push_back(segments[0]);
+    for (size_t i = 1; i < segments.size(); i++) {
+        Segment & last = merged_segments.back();
+        if (segments[i].start <= last.end) {
+            last.end = std::max(last.end, segments[i].end);
+        } else {
+            merged_segments.push_back(segments[i]);
+        }
+    }
+
+    for (const auto & segment : merged_segments) {
+        size_t len = std::max(segment.end - segment.start, static_cast<size_t>(page_size));
+        posix_madvise(reinterpret_cast<void *>(segment.start), len, advice); // hint to load into memory
+        // force to prefetch data
+        if (force && advice == POSIX_MADV_WILLNEED) {
+            volatile char * ptr = reinterpret_cast<volatile char *>(segment.start);
+            for (size_t off = 0; off < len; off += page_size) {
+                (void)ptr[off];
+            }
         }
     }
 }