ggml : fix fallback to CPU for unsupported ops (#15118)

This commit is contained in:
Diego Devesa 2025-08-06 05:37:35 -07:00 committed by GitHub
parent 65c797c4fa
commit 0d8831543c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 27 additions and 25 deletions

View file

@ -35,7 +35,7 @@
// ggml-backend interface
std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types() {
static std::vector<ggml_backend_buffer_type_t> bufts = []() {
std::vector<ggml_backend_buffer_type_t> bufts;
@ -57,8 +57,6 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
}
#endif
bufts.push_back(NULL);
return bufts;
}();
@ -66,14 +64,20 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
}
static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
return ggml_backend_cpu_get_extra_buffers_type().data();
static std::vector<ggml_backend_buffer_type_t> extra_bufts = [] {
std::vector<ggml_backend_buffer_type_t> bufts = ggml_backend_cpu_get_extra_buffer_types();
bufts.push_back(nullptr);
return bufts;
}();
return extra_bufts.data();
GGML_UNUSED(device);
}
static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra && extra == buft) {
for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra == buft) {
return true;
}
}
@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
return true;
}
// extra_buffer_op?
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra) {
auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
if (buf_extra && buf_extra->supports_op(dev, op)) {
return true;
}
}
}
// the other case need host buffer.
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
return false;
// check extra buffer types
// note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary
for (int i = 0; i < 4; i++) {
if (op->src[i] && op->src[i]->buffer &&
ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {
auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;
return buf_extra->supports_op(dev, op);
}
}