manage_graph_tensors: fix segment prefetch

This commit is contained in:
Zonghang Li 2025-02-08 22:44:38 +04:00
parent d2bc5cd502
commit c4c6a642fc

View file

@ -17766,38 +17766,60 @@ static float is_graph_loaded(struct ggml_cgraph * cgraph) {
}
// Issue a posix_madvise() hint for the memory backing the graph's weight leafs.
//
// Every leaf tensor whose name contains "weight" and whose data pointer is set
// contributes the page-aligned range covering [data, data + ggml_nbytes).
// Overlapping or touching ranges are merged, and `advice` is applied to each
// merged range separately, so gaps between non-adjacent tensors are neither
// advised nor touched.
//
//   cgraph - graph whose leafs are scanned
//   advice - advice value forwarded unchanged to posix_madvise()
//   force  - when true and advice == POSIX_MADV_WILLNEED, additionally read one
//            byte from every page of each merged range
static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool force = false) {
    // sysconf() reports failure with -1; using that as a modulus would corrupt
    // every range below, so bail out instead.
    const long page_size_raw = sysconf(_SC_PAGESIZE);
    if (page_size_raw <= 0) {
        return;
    }
    const size_t page_size = static_cast<size_t>(page_size_raw);

    struct Segment {
        size_t start;
        size_t end;
    };

    const int n_leafs = ggml_graph_n_leafs(cgraph);

    std::vector<Segment> segments;
    if (n_leafs > 0) {
        segments.reserve(static_cast<size_t>(n_leafs));
    }

    for (int i = 0; i < n_leafs; i++) {
        struct ggml_tensor * cur = ggml_graph_leaf(cgraph, i);

        if (strstr(cur->name, "weight") == nullptr || cur->data == nullptr) {
            continue;
        }

        size_t start = reinterpret_cast<size_t>(cur->data);
        size_t end   = start + ggml_nbytes(cur);

        // widen to page boundaries: round start down, end up
        start -= start % page_size;
        if (end % page_size != 0) {
            end += page_size - (end % page_size);
        }

        segments.push_back({start, end});
    }

    if (segments.empty()) {
        return;
    }

    std::sort(segments.begin(), segments.end(), [](const Segment & a, const Segment & b) {
        return a.start < b.start;
    });

    // coalesce ranges that overlap or touch, so no page is advised twice
    std::vector<Segment> merged_segments;
    merged_segments.reserve(segments.size());
    merged_segments.push_back(segments[0]);

    for (size_t i = 1; i < segments.size(); i++) {
        Segment & tail = merged_segments.back();
        if (segments[i].start <= tail.end) {
            tail.end = std::max(tail.end, segments[i].end);
        } else {
            merged_segments.push_back(segments[i]);
        }
    }

    for (const auto & segment : merged_segments) {
        // a zero-sized tensor on a page boundary yields start == end; still cover one page
        const size_t len = std::max(segment.end - segment.start, page_size);

        // hint to load into memory (the return value is not checked)
        posix_madvise(reinterpret_cast<void *>(segment.start), len, advice);

        // force to prefetch data
        if (force && advice == POSIX_MADV_WILLNEED) {
            // coarse-grained prefetch: one volatile read per page so the access is not optimized away
            volatile char * ptr = reinterpret_cast<volatile char *>(segment.start);
            for (size_t off = 0; off < len; off += page_size) {
                (void)ptr[off];
            }
        }
    }
}
// decode a batch of tokens by evaluating the transformer
//