server: add margin for draft model for fit (#23485)

This commit is contained in:
Aman Gupta 2026-05-24 14:43:08 +08:00 committed by GitHub
parent fff63b5108
commit 83eebe9d08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 95 additions and 2 deletions

View file

@ -26,7 +26,7 @@ class common_params_fit_exception : public std::runtime_error {
using std::runtime_error::runtime_error;
};
static std::vector<llama_device_memory_data> common_get_device_memory_data(
std::vector<llama_device_memory_data> common_get_device_memory_data(
const char * path_model,
const llama_model_params * mparams,
const llama_context_params * cparams,

View file

@ -1,6 +1,11 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#include "llama.h"
#include "../src/llama-ext.h"
#include <vector>
enum common_params_fit_status {
COMMON_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
@ -30,3 +35,14 @@ void common_fit_print(
struct llama_context_params * cparams);
void common_memory_breakdown_print(const struct llama_context * ctx);
// Load a model + context with no_alloc and return the per-device memory breakdown.
//
// path_model - path of the model file to measure
// mparams    - model parameters used for the measurement load
// cparams    - context parameters used for the measurement load
// devs       - non-const reference; presumably filled with the devices the model was
//              placed on, index-aligned with the returned vector - TODO confirm
// hp_ngl, hp_n_ctx_train, hp_n_expert
//            - non-const references; presumably receive hyperparameters of the loaded
//              model (layer count, training context size, expert count) - TODO confirm
// log_level  - ggml log level forwarded to the measurement; NOTE(review): presumably
//              limits what is logged while loading - confirm against the definition
//
// NOTE(review): failures are presumably reported by throwing (a std::exception-derived
// type) rather than via the return value - confirm against the definition.
std::vector<llama_device_memory_data> common_get_device_memory_data(
const char * path_model,
const struct llama_model_params * mparams,
const struct llama_context_params * cparams,
std::vector<ggml_backend_dev_t> & devs,
uint32_t & hp_ngl,
uint32_t & hp_n_ctx_train,
uint32_t & hp_n_expert,
enum ggml_log_level log_level);

View file

@ -8,6 +8,7 @@
#include "build-info.h"
#include "common.h"
#include "fit.h"
#include "llama.h"
#include "log.h"
#include "sampling.h"
@ -775,7 +776,7 @@ private:
for (auto & [dev, size] : mmproj_mem) {
total += size;
}
SRV_INF("[mtmd] estimated memory usage of mmproj is %.2f MiB\n", total / (1024.0 * 1024.0));
SRV_INF("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB\n", total / (1024.0 * 1024.0));
GGML_ASSERT(!params_base.fit_params_target.empty());
for (auto & [dev, size] : mmproj_mem) {
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
@ -793,6 +794,82 @@ private:
}
}
// Before the target model is fitted, measure how much device memory the speculative
// draft model (or the MTP context) needs and add it to the per-device fit targets.
// Only runs when parameter fitting is enabled.
if (params_base.fit_params) {
    const auto & spec_types = params_base.speculative.types;

    const bool spec_mtp =
        std::find(spec_types.begin(), spec_types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != spec_types.end();
    const bool has_draft = params_base.speculative.has_dft();

    if (has_draft || spec_mtp) {
        // label shared by the summary and the error message below
        const char * what = has_draft ? "draft model" : "MTP context";

        // A separate draft model brings its own weights. The MTP draft context lives on
        // the target model, so in that case only context + compute buffers are new.
        const bool measure_model_bytes = has_draft;

        // start from the target configuration and override what the draft model changes
        common_params params_dft = params_base;
        if (has_draft) {
            const auto & draft = params_base.speculative.draft;

            params_dft.devices               = draft.devices;
            params_dft.model                 = draft.mparams;
            params_dft.n_gpu_layers          = draft.n_gpu_layers;
            params_dft.cache_type_k          = draft.cache_type_k;
            params_dft.cache_type_v          = draft.cache_type_v;
            params_dft.tensor_buft_overrides = draft.tensor_buft_overrides;
        }

        auto mparams_dft = common_model_params_to_llama(params_dft);
        auto cparams_dft = common_context_params_to_llama(params_dft);

        // NOTE(review): the MTP context type is selected whenever MTP is among the
        // speculative types, even if a separate draft model is configured as well -
        // confirm this combination is intended.
        if (spec_mtp) {
            cparams_dft.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
        }
        cparams_dft.n_rs_seq = 0;

        // out-parameters of the measurement call; only the device list is used afterwards
        std::vector<ggml_backend_dev_t> dft_devs;
        uint32_t hp_ngl = 0;
        uint32_t hp_nct = 0;
        uint32_t hp_nex = 0;

        try {
            const auto mem = common_get_device_memory_data(
                params_dft.model.path.c_str(), &mparams_dft, &cparams_dft,
                dft_devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_ERROR);

            GGML_ASSERT(!params_base.fit_params_target.empty());

            // device list of the target model: the explicit list if given, otherwise
            // every registered backend device in registry order
            std::vector<ggml_backend_dev_t> tgt_devices = params.devices;
            if (tgt_devices.empty()) {
                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
                    tgt_devices.push_back(ggml_backend_dev_get(i));
                }
            }

            size_t total = 0;
            for (size_t j = 0; j < dft_devs.size(); ++j) {
                const size_t bytes =
                    (measure_model_bytes ? mem[j].mb.model : 0) +
                    mem[j].mb.context +
                    mem[j].mb.compute;
                total += bytes;

                // the margin is only added for devices that the target model uses too;
                // memory on any other device still counts towards the logged total
                const auto it = std::find(tgt_devices.begin(), tgt_devices.end(), dft_devs[j]);
                if (it == tgt_devices.end()) {
                    continue;
                }

                SRV_DBG("[spec] adding %.2f MiB to fit_params_target for device %s\n",
                        bytes / (1024.0 * 1024.0), ggml_backend_dev_name(dft_devs[j]));
                params_base.fit_params_target[static_cast<size_t>(it - tgt_devices.begin())] += bytes;
            }

            SRV_INF("[spec] estimated memory usage of %s is %.2f MiB\n",
                    what,
                    total / (1024.0 * 1024.0));
        } catch (const std::exception & e) {
            // best effort: a failed measurement is only logged and execution continues
            // with whatever has been added to the fit targets so far
            SRV_ERR("[spec] failed to measure %s memory: %s\n",
                    what, e.what());
        }
    }
}
llama_init = common_init_from_params(params_base);
model_tgt = llama_init->model();