mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-08 01:39:03 +00:00
set backend_embd and out_embd on the device of their adjacent nodes
This commit is contained in:
parent
bfc3f9e185
commit
7aa771e701
2 changed files with 63 additions and 18 deletions
|
@ -88,6 +88,7 @@
|
|||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <chrono>
|
||||
#include <regex>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
|
@ -4478,6 +4479,18 @@ static size_t llama_model_max_nodes(const llama_model & model) {
|
|||
return std::max<size_t>(8192, model.tensors_by_name.size()*5);
|
||||
}
|
||||
|
||||
static int get_layer_id(const ggml_tensor * tensor) {
|
||||
std::string name(ggml_get_name(tensor));
|
||||
std::regex layer_id_regex(R"(\.([0-9]+)\.)");
|
||||
std::smatch match;
|
||||
|
||||
if (std::regex_search(name, match, layer_id_regex)) {
|
||||
return std::stoi(match[1].str());
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
struct llama_model_loader {
|
||||
int n_kv = 0;
|
||||
int n_tensors = 0;
|
||||
|
@ -5085,13 +5098,9 @@ struct llama_model_loader {
|
|||
*addr = mapping->addr;
|
||||
for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
|
||||
try {
|
||||
const auto * weight = get_weight(ggml_get_name(tensor));
|
||||
if (!weight) {
|
||||
continue;
|
||||
}
|
||||
if (weight->idx != idx) {
|
||||
continue;
|
||||
}
|
||||
const llama_tensor_weight * weight = get_weight(ggml_get_name(tensor));
|
||||
if (!weight || weight->idx != idx) continue;
|
||||
|
||||
*first = std::min(*first, weight->offs);
|
||||
*last = std::max(*last, weight->offs + ggml_nbytes(tensor));
|
||||
} catch(...) {
|
||||
|
@ -5100,6 +5109,42 @@ struct llama_model_loader {
|
|||
}
|
||||
}
|
||||
|
||||
void get_mapping_ranges(std::vector<std::pair<size_t, size_t>>& buffer_ranges, void ** addr, int idx, ggml_context * ctx) const {
|
||||
GGML_ASSERT(!mappings.empty());
|
||||
const auto & mapping = mappings.at(idx);
|
||||
*addr = mapping->addr;
|
||||
|
||||
for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
|
||||
try {
|
||||
const llama_tensor_weight * weight = get_weight(ggml_get_name(tensor));
|
||||
if (!weight || weight->idx != idx) continue;
|
||||
|
||||
size_t tensor_first = weight->offs;
|
||||
size_t tensor_last = tensor_first + ggml_nbytes(tensor);
|
||||
|
||||
auto it = std::lower_bound(
|
||||
buffer_ranges.begin(), buffer_ranges.end(), std::make_pair(tensor_first, tensor_last),
|
||||
[](const std::pair<size_t, size_t>& a, const std::pair<size_t, size_t>& b) {
|
||||
return a.first < b.first;
|
||||
}
|
||||
);
|
||||
|
||||
if (it != buffer_ranges.begin() && (it - 1)->second >= tensor_first) {
|
||||
--it;
|
||||
it->second = std::max(it->second, tensor_last);
|
||||
} else {
|
||||
it = buffer_ranges.insert(it, {tensor_first, tensor_last});
|
||||
}
|
||||
|
||||
while (it + 1 != buffer_ranges.end() && (it + 1)->first <= it->second) {
|
||||
it->second = std::max(it->second, (it + 1)->second);
|
||||
buffer_ranges.erase(it + 1);
|
||||
}
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for backwards compatibility, does not support ggml-backend
|
||||
void load_data_for(struct ggml_tensor * cur) const {
|
||||
const auto & w = require_weight(ggml_get_name(cur));
|
||||
|
@ -10322,7 +10367,7 @@ struct llm_build_context {
|
|||
const llm_build_cb & cb) {
|
||||
lctx.backend_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_embd, batch.n_tokens);
|
||||
cb(lctx.backend_embd, "backend_embd", -1);
|
||||
ggml_set_input(lctx.backend_embd);
|
||||
// ggml_set_input(lctx.backend_embd); // set it on the device of the adjacent node
|
||||
return lctx.backend_embd;
|
||||
}
|
||||
|
||||
|
@ -10333,7 +10378,7 @@ struct llm_build_context {
|
|||
const llm_build_cb & cb) {
|
||||
lctx.out_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_embd, n_outputs);
|
||||
cb(lctx.out_embd, "out_embd", -1);
|
||||
ggml_set_input(lctx.out_embd);
|
||||
// ggml_set_input(lctx.out_embd); // set it on the device of the adjacent node
|
||||
return lctx.out_embd;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue