set backend_embd and out_embd on the device of their adjacent nodes

This commit is contained in:
Lizonghang 2024-11-02 09:41:40 +04:00
parent bfc3f9e185
commit 7aa771e701
2 changed files with 63 additions and 18 deletions

View file

@ -1963,7 +1963,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
ggml_backend_sched_print_assignments(sched, graph);
}
// swap node_backend_ids and leaf _backend_ids with prevs
// swap node_backend_ids and leaf_backend_ids with prevs
{
int * tmp = sched->node_backend_ids;
sched->node_backend_ids = sched->prev_node_backend_ids;
@ -2205,14 +2205,14 @@ ggml_backend_sched_t ggml_backend_sched_new(
// initialize hash table
// FIXME: needs to be size*2 to account for leafs (do it in graph_split instead)
sched->hash_set = ggml_hash_set_new(graph_size);
sched->hash_set = ggml_hash_set_new(graph_size);
sched->hv_tensor_backend_ids = (int *) malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
sched->hv_tensor_copies = (ggml_tensor **) malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph
const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
sched->node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
sched->leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
sched->node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
sched->leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
@ -2400,7 +2400,7 @@ void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor
(char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
tensor->buffer = buffer;
tensor->data = addr;
tensor->data = addr;
ggml_backend_buffer_init_tensor(buffer, tensor);
}