ggml: backend-agnostic tensor parallelism (experimental) (#19378)

* ggml: backend-agnostic tensor parallelism

* support for GPT-OSS, Qwen 3 MoE

* partial Vulkan fix

* add support for 4/8 GPUs

* unconditional peer access

* re-use buffers + ggml contexts

* fix output pattern

* NCCL support

* GGML: HIP: add RCCL support

* Remove shfl and AllReduce from backend interface

* move allocation workaround out of ggml-alloc.c

* 2d tensor set/get support

* Fix the seg fault without NCCL

* Apply suggestion from JohannesGaessler

* support for tensor dims % n_devs != 0

* fix view_offs scaling

* arbitrary num. of GPUs/tensor split

* fix compilation

* better granularity estimate

* Support device-specific host buffer types if all underlying backends expose the same type. This allows using pinned memory instead of pageable memory for CUDA.

Fix compilation errors.
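
As an illustration of the idea (not the code from this change): a meta device that wraps several real devices can only advertise a host buffer type when every wrapped device reports the same one. The helper name `meta_host_buffer_type` and the `devs` list below are hypothetical:

```cpp
#include <vector>
#include "ggml-backend.h"

// Sketch: only expose a host buffer type for the meta device when all
// underlying devices agree on it; otherwise fall back to pageable memory.
static ggml_backend_buffer_type_t meta_host_buffer_type(const std::vector<ggml_backend_dev_t> & devs) {
    ggml_backend_buffer_type_t host_buft = nullptr;
    for (ggml_backend_dev_t dev : devs) {
        ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
        if (buft == nullptr) {
            return nullptr; // this device has no host buffer type
        }
        if (host_buft == nullptr) {
            host_buft = buft;
        } else if (buft != host_buft) {
            return nullptr; // mixed backends -> no common host buffer type
        }
    }
    return host_buft; // e.g. CUDA pinned memory when all devices are CUDA
}
```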

* partial Qwen 3 Next support

* Fix qwen3 30b (#8)

* Fix crash with Qwen-30B-A3B Q4_0

Qwen-30B-A3B Q4_0 has an intermediate dimension of 768. Using a granularity of 256 forces an uneven split between GPUs, which is not supported by the current implementation (see the worked example after this list).

* Decide block size based on tensor quantization type
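
Worked example of the crash above (illustrative only; `slice_size` is a hypothetical helper, not code from this PR):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: number of columns assigned to device i when splitting
// n columns across n_devs devices in multiples of the given granularity.
static int64_t slice_size(int64_t n, int n_devs, int64_t granularity, int i) {
    const int64_t n_blocks = n / granularity;                  // assumes n % granularity == 0
    const int64_t per_dev  = (n_blocks + n_devs - 1) / n_devs; // round up in whole blocks
    const int64_t i0       = std::min<int64_t>((i + 0) * per_dev, n_blocks);
    const int64_t i1       = std::min<int64_t>((i + 1) * per_dev, n_blocks);
    return (i1 - i0) * granularity;
}

int main() {
    // granularity 256: 768 columns split as 512 + 256 -> uneven, unsupported
    printf("%lld %lld\n", (long long) slice_size(768, 2, 256, 0), (long long) slice_size(768, 2, 256, 1));
    // Q4_0 block size 32: 768 columns split as 384 + 384 -> even
    printf("%lld %lld\n", (long long) slice_size(768, 2, 32, 0), (long long) slice_size(768, 2, 32, 1));
    return 0;
}
```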

* Fix crashes due to KV cache serialization (#9)

KV cache serialization requires non-zero offsets on the tensor. Add support in the meta backend to set/get a tensor with a non-zero offset.
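
The call pattern the meta backend now has to handle looks roughly like this (illustrative sketch; `roundtrip_row` is a hypothetical helper, not code from this PR):

```cpp
#include <cstdint>
#include <vector>
#include "ggml.h"
#include "ggml-backend.h"

// Illustrative only: copy one row of a KV cache tensor out and back in,
// using a non-zero byte offset -- the pattern the meta backend must support.
// Assumes a contiguous 2D tensor.
static void roundtrip_row(ggml_tensor * kv_tensor, int64_t row) {
    std::vector<uint8_t> buf(ggml_row_size(kv_tensor->type, kv_tensor->ne[0]));

    const size_t offset = row * buf.size(); // non-zero for row > 0

    ggml_backend_tensor_get(kv_tensor, buf.data(), offset, buf.size()); // read one row
    ggml_backend_tensor_set(kv_tensor, buf.data(), offset, buf.size()); // write it back
}
```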

* metal : fix build (#7)

* static memory allocations, fix usage count

* fix tensor granularity

* more even memory distribution

* use BF16 for allreduce

* rebase fixup

* better error message for unsupported architectures

* Fix device mismatch during scatter of allReduce. (#11)

There was a mismatch between the dst buffer device and the backend device, which caused synchronous copies to be used.
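
A sketch of the invariant behind the fix (not the actual patch; `copy_tensor` is a hypothetical helper): a copy only stays asynchronous when the destination buffer belongs to the device of the backend the copy is issued on.

```cpp
#include "ggml.h"
#include "ggml-backend.h"

// Sketch: pick async vs. blocking copy depending on whether the destination
// buffer actually belongs to the device of backend_dst.
static void copy_tensor(ggml_backend_t backend_src, ggml_backend_t backend_dst,
                        ggml_tensor * src, ggml_tensor * dst) {
    ggml_backend_dev_t backend_dev = ggml_backend_get_device(backend_dst);
    ggml_backend_dev_t buffer_dev  = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(dst->buffer));

    if (buffer_dev == backend_dev) {
        ggml_backend_tensor_copy_async(backend_src, backend_dst, src, dst);
    } else {
        ggml_backend_tensor_copy(src, dst); // synchronous fallback
    }
}
```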

* Enable the previous allreduce implementation. It is better in both perf and stability (#12)

* delay AllReduce for MoE to reduce I/O

* build : clean-up compile warnings

* backend : move most of the meta backend API to ggml-backend-impl.h

* cont : hide unused public API in the implementation

* llama : use llama_device + remove ggml_backend_dev_is_meta()

* ggml-backend : remove unused alloc include

* minor : remove regex include

* ggml : introduce ggml-ext.h for staging new APIs

* rebase fixup

* fix tests

* llama : more robust logic for determining Meta devices (#16)

* llama : more robust logic for determining Meta devices

* cont : fix devs size check

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* cont : fix log type

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* disable roundtrip for meta backend

* fix arch selection

* Qwen 3.5 support

* fix Gemma 4 MoE

* fix OpenVINO, SYCL

* fix test-llama-archs for CPU-only builds

* Fix Qwen 3.5 MoE

* disable meta backend tests for WebGPU

* tests : filter CPU-based devices from the Meta backend tests (#17)

* meta : formatting, naming, indentation (#18)

* formatting : llama-model.cpp

* formatting : ggml-ext.h

* formatting : ggml-backend-meta.cpp

* meta : add TODO

* add documentation

* better error messages

* fix GPT-OSS

---------

Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Gaurav Garg <gaugarg@nvidia.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Johannes Gäßler
Date:   2026-04-09 16:42:19 +02:00 (committed by GitHub)
Commit: d6f3030047
Parent: 009a113326
48 changed files with 3198 additions and 342 deletions


@@ -1,5 +1,6 @@
 #include "llama-context.h"
+#include "ggml.h"
 #include "llama-arch.h"
 #include "llama-impl.h"
 #include "llama-batch.h"
@@ -8,6 +9,7 @@
 #include "llama-mmap.h"
 #include "llama-model.h"
+#include "llama-ext.h"
 #include "llama.h"
 #include <cinttypes>
 #include <cmath>
@@ -217,10 +219,10 @@ llama_context::llama_context(
     if (!hparams.vocab_only) {
         // GPU backends
-        for (auto * dev : model.devices) {
-            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+        for (const auto & dev : model.devices) {
+            ggml_backend_t backend = ggml_backend_dev_init(dev.dev, nullptr);
             if (backend == nullptr) {
-                throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
+                throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev.dev)));
             }
             backends.emplace_back(backend);
         }
@@ -295,8 +297,8 @@ llama_context::llama_context(
             if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
                 // use the host buffer of the first device CPU for faster transfer of the intermediate state
-                auto * dev = model.devices[0];
-                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+                const auto & dev = model.devices[0];
+                auto * host_buft = ggml_backend_dev_host_buffer_type(dev.dev);
                 if (host_buft) {
                     buft = host_buft;
                 }
@@ -1020,9 +1022,11 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void
     for (auto & backend : backends) {
         auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
-        auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
-        if (set_abort_callback_fn) {
-            set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
+        if (reg) {
+            auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
+            if (set_abort_callback_fn) {
+                set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
+            }
         }
     }
 }
@@ -2942,6 +2946,21 @@ llama_context * llama_init_from_model(
         params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
     }
+    if (model->split_mode() == LLAMA_SPLIT_MODE_TENSOR) {
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            LLAMA_LOG_INFO("%s: enabling flash_attn since it is required for SPLIT_MODE_TENSOR\n", __func__);
+            params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+        }
+        if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_ENABLED) {
+            LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__);
+            return nullptr;
+        }
+        if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) {
+            LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__);
+            return nullptr;
+        }
+    }
     if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) {
         const uint32_t blck_size = ggml_blck_size(params.type_k);
         for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
@@ -3475,7 +3494,7 @@ void llama_perf_context_reset(llama_context * ctx) {
 }
 void llama_memory_breakdown_print(const struct llama_context * ctx) {
-    const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices;
+    const auto & devices = ctx->get_model().devices;
     std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
@@ -3511,7 +3530,7 @@ void llama_memory_breakdown_print(const struct llama_context * ctx) {
         if (dev) {
             int i_dev = -1;
             for (size_t i = 0; i < devices.size(); i++) {
-                if (devices[i] == dev) {
+                if (devices[i].dev == dev) {
                     i_dev = i;
                     break;
                 }
@@ -3528,7 +3547,7 @@ void llama_memory_breakdown_print(const struct llama_context * ctx) {
     // print memory breakdown for each device:
     for (size_t i = 0; i < devices.size(); i++) {
-        ggml_backend_dev_t dev = devices[i];
+        ggml_backend_dev_t dev = devices[i].dev;
         llama_memory_breakdown_data mb = mb_dev[i];
         const std::string name = ggml_backend_dev_name(dev);