sched : fix possible use of wrong ids tensor when offloading moe prompt processing (#15488)

This commit is contained in:
Diego Devesa 2025-08-21 14:09:32 -07:00 committed by GitHub
parent cd36b5e5c7
commit 54a241f505
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 8 deletions

View file

@ -1755,7 +1755,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) { [](common_params & params) {
params.warmup = false; params.warmup = false;
} }
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL})); ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
add_opt(common_arg( add_opt(common_arg(
{"--spm-infill"}, {"--spm-infill"},
string_format( string_format(

View file

@ -1355,15 +1355,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
std::vector<int32_t> ids; std::vector<int32_t> ids;
std::vector<ggml_bitset_t> used_ids; std::vector<ggml_bitset_t> used_ids;
for (int i = 0; i < sched->n_splits; i++) { for (int split_id = 0; split_id < sched->n_splits; split_id++) {
struct ggml_backend_sched_split * split = &splits[i]; struct ggml_backend_sched_split * split = &splits[split_id];
int split_backend_id = split->backend_id; int split_backend_id = split->backend_id;
ggml_backend_t split_backend = sched->backends[split_backend_id]; ggml_backend_t split_backend = sched->backends[split_backend_id];
// copy the input tensors to the split backend // copy the input tensors to the split backend
for (int j = 0; j < split->n_inputs; j++) { for (int input_id = 0; input_id < split->n_inputs; input_id++) {
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]); ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
struct ggml_tensor * input = split->inputs[j]; struct ggml_tensor * input = split->inputs[input_id];
struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy); struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
if (input->flags & GGML_TENSOR_FLAG_INPUT) { if (input->flags & GGML_TENSOR_FLAG_INPUT) {
@ -1398,10 +1398,22 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
// get the ids // get the ids
ggml_tensor * ids_tensor = node->src[2]; ggml_tensor * ids_tensor = node->src[2];
ggml_backend_t ids_backend = split_backend;
// if the ids tensor is also an input of the split, it may not have been copied yet to the split backend
// in that case, we use the original ids tensor
for (int i = input_id + 1; i < split->n_inputs; i++) {
if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) {
ids_tensor = split->inputs[i];
ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]);
break;
}
}
if (ids_tensor != prev_ids_tensor) { if (ids_tensor != prev_ids_tensor) {
ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t)); ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor)); ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
ggml_backend_synchronize(split_backend); ggml_backend_synchronize(ids_backend);
// find the used experts // find the used experts
used_ids.clear(); used_ids.clear();
@ -1409,6 +1421,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) { for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {
for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) { for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {
int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)]; int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];
GGML_ASSERT(id >= 0 && id < n_expert);
ggml_bitset_set(used_ids.data(), id); ggml_bitset_set(used_ids.data(), id);
} }
} }