diff --git a/common/fit.cpp b/common/fit.cpp index c10cb7f08..668d892e9 100644 --- a/common/fit.cpp +++ b/common/fit.cpp @@ -26,7 +26,7 @@ class common_params_fit_exception : public std::runtime_error { using std::runtime_error::runtime_error; }; -static std::vector common_get_device_memory_data( +std::vector common_get_device_memory_data( const char * path_model, const llama_model_params * mparams, const llama_context_params * cparams, diff --git a/common/fit.h b/common/fit.h index e066092ec..643d34200 100644 --- a/common/fit.h +++ b/common/fit.h @@ -1,6 +1,11 @@ #pragma once #include "ggml.h" +#include "ggml-backend.h" +#include "llama.h" +#include "../src/llama-ext.h" + +#include enum common_params_fit_status { COMMON_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit @@ -30,3 +35,14 @@ void common_fit_print( struct llama_context_params * cparams); void common_memory_breakdown_print(const struct llama_context * ctx); + +// Load a model + context with no_alloc and return the per-device memory breakdown. +std::vector common_get_device_memory_data( + const char * path_model, + const struct llama_model_params * mparams, + const struct llama_context_params * cparams, + std::vector & devs, + uint32_t & hp_ngl, + uint32_t & hp_n_ctx_train, + uint32_t & hp_n_expert, + enum ggml_log_level log_level); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index b939e3b75..c3daafd0d 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -8,6 +8,7 @@ #include "build-info.h" #include "common.h" +#include "fit.h" #include "llama.h" #include "log.h" #include "sampling.h" @@ -775,7 +776,7 @@ private: for (auto & [dev, size] : mmproj_mem) { total += size; } - SRV_INF("[mtmd] estimated memory usage of mmproj is %.2f MiB\n", total / (1024.0 * 1024.0)); + SRV_INF("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB\n", total / (1024.0 * 1024.0)); GGML_ASSERT(!params_base.fit_params_target.empty()); for (auto & [dev, size] : mmproj_mem) { for (size_t i = 0; i < ggml_backend_dev_count(); i++) { @@ -793,6 +794,82 @@ private: } } + // optionally reserve VRAM for the draft / MTP context before fitting the target model + if (params_base.fit_params) { + const bool spec_mtp = std::find(params_base.speculative.types.begin(), + params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); + const bool has_draft = params_base.speculative.has_dft(); + + if (has_draft || spec_mtp) { + common_params params_dft = params_base; + bool measure_model_bytes = true; + + if (has_draft) { + const auto & params_spec = params_base.speculative.draft; + params_dft.devices = params_spec.devices; + params_dft.model = params_spec.mparams; + params_dft.n_gpu_layers = params_spec.n_gpu_layers; + params_dft.cache_type_k = params_spec.cache_type_k; + params_dft.cache_type_v = params_spec.cache_type_v; + params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides; + } else { + // MTP draft context lives on the target model, only context+compute are new + measure_model_bytes = false; + } + + auto mparams_dft = common_model_params_to_llama(params_dft); + auto cparams_dft = common_context_params_to_llama(params_dft); + if (spec_mtp) { + cparams_dft.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + } + cparams_dft.n_rs_seq = 0; + + std::vector devs; + uint32_t hp_ngl = 0; + uint32_t hp_nct = 0; + uint32_t hp_nex = 0; + try { + auto dmd = common_get_device_memory_data( + params_dft.model.path.c_str(), &mparams_dft, &cparams_dft, + devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_ERROR); + + GGML_ASSERT(!params_base.fit_params_target.empty()); + size_t total = 0; + + std::vector tgt_devices = params.devices; + + if (tgt_devices.empty()) { + for(size_t i = 0; i < ggml_backend_dev_count(); ++i) { + tgt_devices.push_back(ggml_backend_dev_get(i)); + } + } + + for (size_t j = 0; j < devs.size(); ++j) { + const size_t bytes = + (measure_model_bytes ? dmd[j].mb.model : 0) + + dmd[j].mb.context + + dmd[j].mb.compute; + total += bytes; + for (size_t i = 0; i < tgt_devices.size(); i++) { + if (tgt_devices[i] == devs[j]) { + SRV_DBG("[spec] adding %.2f MiB to fit_params_target for device %s\n", + bytes / (1024.0 * 1024.0), ggml_backend_dev_name(devs[j])); + params_base.fit_params_target[i] += bytes; + break; + } + } + } + SRV_INF("[spec] estimated memory usage of %s is %.2f MiB\n", + has_draft ? "draft model" : "MTP context", + total / (1024.0 * 1024.0)); + } catch (const std::exception & e) { + SRV_ERR("[spec] failed to measure %s memory: %s\n", + has_draft ? "draft model" : "MTP context", e.what()); + } + } + } + llama_init = common_init_from_params(params_base); model_tgt = llama_init->model();