mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-07 02:19:03 +00:00
manage_graph_tensors: fix segment prefetch
This commit is contained in:
parent
d2bc5cd502
commit
c4c6a642fc
1 changed file with 39 additions and 17 deletions
|
@ -17766,35 +17766,57 @@ static float is_graph_loaded(struct ggml_cgraph * cgraph) {
|
|||
}
|
||||
|
||||
static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool force = false) {
|
||||
size_t first = SIZE_MAX;
|
||||
size_t last = 0;
|
||||
long page_size = sysconf(_SC_PAGESIZE);
|
||||
|
||||
struct Segment {
|
||||
size_t start;
|
||||
size_t end;
|
||||
};
|
||||
std::vector<Segment> segments;
|
||||
|
||||
for (int i = 0; i < ggml_graph_n_leafs(cgraph); i++) {
|
||||
struct ggml_tensor * cur = ggml_graph_leaf(cgraph, i);
|
||||
|
||||
if (strstr(cur->name, "weight") == nullptr || cur->data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t addr = reinterpret_cast<size_t>(cur->data);
|
||||
first = std::min(first, addr);
|
||||
last = std::max(last, addr + ggml_nbytes(cur));
|
||||
size_t size = ggml_nbytes(cur);
|
||||
size_t first = reinterpret_cast<size_t>(cur->data);
|
||||
size_t last = first + size;
|
||||
|
||||
first = first - (first % page_size);
|
||||
if (last % page_size != 0) {
|
||||
last = last + (page_size - (last % page_size));
|
||||
}
|
||||
segments.push_back({first, last});
|
||||
}
|
||||
|
||||
// align addr
|
||||
llama_mmap::align_range(&first, &last, page_size);
|
||||
size_t len = std::max(last - first, static_cast<size_t>(page_size));
|
||||
if (segments.empty()) return;
|
||||
|
||||
// hint to load memory
|
||||
posix_madvise(reinterpret_cast<void *>(first), len, advice);
|
||||
std::sort(segments.begin(), segments.end(), [](const Segment & a, const Segment & b) {
|
||||
return a.start < b.start;
|
||||
});
|
||||
|
||||
// if advice is POSIX_MADV_WILLNEED, force to prefetch data
|
||||
if (force && advice == POSIX_MADV_WILLNEED) {
|
||||
// coarse-grained prefetch
|
||||
volatile char * ptr = (volatile char *)first;
|
||||
for (size_t off = 0; off < len; off += page_size) {
|
||||
(void)ptr[off];
|
||||
std::vector<Segment> merged_segments;
|
||||
merged_segments.push_back(segments[0]);
|
||||
for (size_t i = 1; i < segments.size(); i++) {
|
||||
Segment & last = merged_segments.back();
|
||||
if (segments[i].start <= last.end) {
|
||||
last.end = std::max(last.end, segments[i].end);
|
||||
} else {
|
||||
merged_segments.push_back(segments[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & segment : merged_segments) {
|
||||
size_t len = std::max(segment.end - segment.start, static_cast<size_t>(page_size));
|
||||
posix_madvise(reinterpret_cast<void *>(segment.start), len, advice); // hint to load into memory
|
||||
// force to prefetch data
|
||||
if (force && advice == POSIX_MADV_WILLNEED) {
|
||||
volatile char * ptr = reinterpret_cast<volatile char *>(segment.start);
|
||||
for (size_t off = 0; off < len; off += page_size) {
|
||||
(void)ptr[off];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue