mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-08 09:59:50 +00:00
ggml: backend-agnostic tensor parallelism (experimental) (#19378)
* ggml: backend-agnostic tensor parallelism * support for GPT-OSS, Qwen 3 MoE * partial Vulkan fix * add support for 4/8 GPUs * unconditional peer access * re-use buffers + ggml contexts * fix output pattern * NCCL support * GGML: HIP: add RCCL support * Remove shfl and AllReduce from backend interface * move allocation workaround out of ggml-alloc.c * 2d tensor set/get support * Fix the seg fault without NCCL * Apply suggestion from JohannesGaessler * support for tensor dims % n_devs != 0 * fix view_offs scaling * arbitrary num. of GPUs/tensor split * fix compilation * better granularity estimate * Support device-specific host buffer types if all underlying backends expose the same type. This allows using pinned memory instead of pageable memory for CUDA. Fix compilation errors. * partial Qwen 3 Next support * Fix qwen3 30b (#8) * Fix crash with Qwen-30B-A3B Q4_0 Qwen-30B-A3B Q4_0 has an intermediate dimension of 768. Using a granularity of 256 forces an uneven split between GPUs, which is not supported by the current implementation. * Decide block size based on tensor quantization type * Fix crashes due to KV cache serialization (#9) KV cache serialization requires non-zero offsets on the tensor. Add support in the meta backend to set/get a tensor with a non-zero offset. * metal : fix build (#7) * static memory allocations, fix usage count * fix tensor granularity * more even memory distribution * use BF16 for allreduce * rebase fixup * better error message for unsupported architectures * Fix device mismatch during scatter of allReduce. (#11) There is a mismatch between the dst buffer device and the backend device, causing the use of sync copies * Enable the previous allreduce implementation. It is better in both perf and stability (#12) * delay AllReduce for Moe for less I/O * build : clean-up compile warnings * backend : move most of the meta backend API to ggml-backend-impl.h * cont : hide unused public API in the implementation * llama : use llama_device + remove ggml_backend_dev_is_meta() * ggml-backend : remove unused alloc include * minor : remove regex include * ggml : introduce ggml-ext.h for staging new APIs * rebase fixup * fix tests * llama : more robust logic for determining Meta devices (#16) * llama : more robust logic for determining Meta devices * cont : fix devs size check Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * cont : fix log type Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * disable roundtrip for meta backend * fix arch selection * Qwen 3.5 support * fix Gemma 4 MoE * fix OpenVino, SYCL * fix test-llama-archs for CPU-only builds * Fix Qwen 3.5 MoE * disable meta backend tests for WebGPU * tests : filter CPU-based devices from the Meta backend tests (#17) * meta : formatting, naming, indentation (#18) * formatting : llama-model.cpp * formatting : ggml-ext.h * formatting : ggml-backend-meta.cpp * meta : add TODO * add documentation * better error messages * fix GPT-OSS --------- Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz> Co-authored-by: Gaurav Garg <gaugarg@nvidia.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
009a113326
commit
d6f3030047
48 changed files with 3198 additions and 342 deletions
|
|
@ -1,5 +1,6 @@
|
|||
#include "llama-context.h"
|
||||
|
||||
#include "ggml.h"
|
||||
#include "llama-arch.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
|
|
@ -8,6 +9,7 @@
|
|||
#include "llama-mmap.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-ext.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cmath>
|
||||
|
|
@ -217,10 +219,10 @@ llama_context::llama_context(
|
|||
|
||||
if (!hparams.vocab_only) {
|
||||
// GPU backends
|
||||
for (auto * dev : model.devices) {
|
||||
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
||||
for (const auto & dev : model.devices) {
|
||||
ggml_backend_t backend = ggml_backend_dev_init(dev.dev, nullptr);
|
||||
if (backend == nullptr) {
|
||||
throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
|
||||
throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev.dev)));
|
||||
}
|
||||
backends.emplace_back(backend);
|
||||
}
|
||||
|
|
@ -295,8 +297,8 @@ llama_context::llama_context(
|
|||
|
||||
if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
|
||||
// use the host buffer of the first device CPU for faster transfer of the intermediate state
|
||||
auto * dev = model.devices[0];
|
||||
auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
|
||||
const auto & dev = model.devices[0];
|
||||
auto * host_buft = ggml_backend_dev_host_buffer_type(dev.dev);
|
||||
if (host_buft) {
|
||||
buft = host_buft;
|
||||
}
|
||||
|
|
@ -1020,9 +1022,11 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void
|
|||
|
||||
for (auto & backend : backends) {
|
||||
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
|
||||
auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
|
||||
if (set_abort_callback_fn) {
|
||||
set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
|
||||
if (reg) {
|
||||
auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
|
||||
if (set_abort_callback_fn) {
|
||||
set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2942,6 +2946,21 @@ llama_context * llama_init_from_model(
|
|||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
||||
}
|
||||
|
||||
if (model->split_mode() == LLAMA_SPLIT_MODE_TENSOR) {
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
|
||||
LLAMA_LOG_INFO("%s: enabling flash_attn since it is required for SPLIT_MODE_TENSOR\n", __func__);
|
||||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
|
||||
}
|
||||
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_ENABLED) {
|
||||
LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) {
|
||||
LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_k);
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
|
|
@ -3475,7 +3494,7 @@ void llama_perf_context_reset(llama_context * ctx) {
|
|||
}
|
||||
|
||||
void llama_memory_breakdown_print(const struct llama_context * ctx) {
|
||||
const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices;
|
||||
const auto & devices = ctx->get_model().devices;
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
|
||||
|
||||
|
|
@ -3511,7 +3530,7 @@ void llama_memory_breakdown_print(const struct llama_context * ctx) {
|
|||
if (dev) {
|
||||
int i_dev = -1;
|
||||
for (size_t i = 0; i < devices.size(); i++) {
|
||||
if (devices[i] == dev) {
|
||||
if (devices[i].dev == dev) {
|
||||
i_dev = i;
|
||||
break;
|
||||
}
|
||||
|
|
@ -3528,7 +3547,7 @@ void llama_memory_breakdown_print(const struct llama_context * ctx) {
|
|||
|
||||
// print memory breakdown for each device:
|
||||
for (size_t i = 0; i < devices.size(); i++) {
|
||||
ggml_backend_dev_t dev = devices[i];
|
||||
ggml_backend_dev_t dev = devices[i].dev;
|
||||
llama_memory_breakdown_data mb = mb_dev[i];
|
||||
|
||||
const std::string name = ggml_backend_dev_name(dev);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue