From 06d26dfdff4097dc51eac20155371a9cfd53e094 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 29 May 2026 16:30:55 +0200 Subject: [PATCH] download: add option to skip_download (#23059) * download: add option to skip_download * fix * fix 2 * if file doesn't exist, respect skip_download flag --- common/arg.cpp | 74 +++++++++++++++++++--------------- common/arg.h | 7 +++- common/common.h | 3 +- common/download.cpp | 26 ++++++++---- common/download.h | 13 ++++-- tools/server/README.md | 15 +++++-- tools/server/server-models.cpp | 70 +++++++++++++++++--------------- tools/server/server-models.h | 1 + 8 files changed, 126 insertions(+), 83 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 51631765f..e0f6c6066 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -340,9 +340,7 @@ struct handle_model_result { }; static handle_model_result common_params_handle_model(struct common_params_model & model, - const std::string & bearer_token, - bool offline, - bool search_mtp = false) { + const common_download_opts & opts) { handle_model_result result; if (!model.docker_repo.empty()) { @@ -354,10 +352,9 @@ static handle_model_result common_params_handle_model(struct common_params_model model.hf_file = model.path; model.path = ""; } - common_download_opts opts; - opts.bearer_token = bearer_token; - opts.offline = offline; - auto download_result = common_download_model(model, opts, true, search_mtp); + common_download_opts hf_opts = opts; + hf_opts.download_mmproj = true; // also look for mmproj when downloading hf model + auto download_result = common_download_model(model, hf_opts); if (download_result.model_path.empty()) { throw std::runtime_error("failed to download model from Hugging Face"); @@ -382,9 +379,6 @@ static handle_model_result common_params_handle_model(struct common_params_model model.path = fs_get_cache_file(string_split(f, '/').back()); } - common_download_opts opts; - opts.bearer_token = bearer_token; - opts.offline = offline; auto 
download_result = common_download_model(model, opts); if (download_result.model_path.empty()) { throw std::runtime_error("failed to download model from " + model.url); @@ -441,35 +435,49 @@ static bool parse_bool_value(const std::string & value) { // CLI argument parsing functions // -void common_params_handle_models(common_params & params, llama_example curr_ex) { +bool common_params_handle_models(common_params & params, llama_example curr_ex) { const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), params.speculative.types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); - auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_draft_mtp); - if (params.no_mmproj) { - params.mmproj = {}; - } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { - // optionally, handle mmproj model when -hf is specified - params.mmproj = res.mmproj; - } - // only download mmproj if the current example is using it - for (const auto & ex : mmproj_examples) { - if (curr_ex == ex) { - common_params_handle_model(params.mmproj, params.hf_token, params.offline); - break; + common_download_opts opts; + opts.bearer_token = params.hf_token; + opts.offline = params.offline; + opts.skip_download = params.skip_download; + opts.download_mtp = spec_type_draft_mtp; + + try { + auto res = common_params_handle_model(params.model, opts); + if (params.no_mmproj) { + params.mmproj = {}; + } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { + // optionally, handle mmproj model when -hf is specified + params.mmproj = res.mmproj; } + // only download mmproj if the current example is using it + for (const auto & ex : mmproj_examples) { + if (curr_ex == ex) { + common_params_handle_model(params.mmproj, opts); + break; + } + } + + // when --spec-type mtp is set and no draft model was provided explicitly, + // fall back to the MTP head discovered 
alongside the -hf model + if (spec_type_draft_mtp && res.found_mtp && + params.speculative.draft.mparams.path.empty() && + params.speculative.draft.mparams.hf_repo.empty() && + params.speculative.draft.mparams.url.empty()) { + params.speculative.draft.mparams.path = res.mtp.path; + } + common_params_handle_model(params.speculative.draft.mparams, opts); + common_params_handle_model(params.vocoder.model, opts); + return true; + } catch (const common_skip_download_exception &) { + return false; + } catch (const std::exception &) { + throw; } - // when --spec-type mtp is set and no draft model was provided explicitly, - // fall back to the MTP head discovered alongside the -hf model - if (spec_type_draft_mtp && res.found_mtp && - params.speculative.draft.mparams.path.empty() && - params.speculative.draft.mparams.hf_repo.empty() && - params.speculative.draft.mparams.url.empty()) { - params.speculative.draft.mparams.path = res.mtp.path; - } - common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline); - common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); } static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) { diff --git a/common/arg.h b/common/arg.h index 2a85f09f3..0010f2a9a 100644 --- a/common/arg.h +++ b/common/arg.h @@ -129,8 +129,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map & args); -// Populate model paths (main model, mmproj, etc) from -hf if necessary -void common_params_handle_models(common_params & params, llama_example curr_ex); +// populate model paths (main model, mmproj, etc) from -hf if necessary +// return true if the model is ready to use +// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc) +// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. 
ETag check failed) +bool common_params_handle_models(common_params & params, llama_example curr_ex); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/common/common.h b/common/common.h index 9855d3f36..99898800d 100644 --- a/common/common.h +++ b/common/common.h @@ -479,7 +479,7 @@ struct common_params { std::set model_alias; // model aliases // NOLINT std::set model_tags; // model tags (informational, not used for routing) // NOLINT - std::string hf_token = ""; // HF token // NOLINT + std::string hf_token = ""; // HF token (aka bearer token) // NOLINT std::string prompt = ""; // NOLINT std::string system_prompt = ""; // NOLINT std::string prompt_file = ""; // store the external prompt file name // NOLINT @@ -507,6 +507,7 @@ struct common_params { int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector bool offline = false; + bool skip_download = false; // skip model file downloading int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. 
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line diff --git a/common/download.cpp b/common/download.cpp index 103bc408f..40f6eb780 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -292,6 +292,10 @@ static int common_download_file_single_online(const std::string & url, const bool file_exists = std::filesystem::exists(path); + if (!file_exists && opts.skip_download) { + return -2; // file is missing and download is disabled + } + if (file_exists && skip_etag) { LOG_DBG("%s: using cached file: %s\n", __func__, path.c_str()); return 304; // 304 Not Modified - fake cached response @@ -357,6 +361,10 @@ static int common_download_file_single_online(const std::string & url, LOG_DBG("%s: using cached file (same etag): %s\n", __func__, path.c_str()); return 304; // 304 Not Modified - fake cached response } + // past this point, the file exists but is different from the server version, so we need to redownload it + if (opts.skip_download) { + return -2; // special code to indicate that the download was skipped due to etag mismatch + } if (remove(path.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); return -1; @@ -775,13 +783,13 @@ static std::vector get_url_tasks(const common_params_model & mode } common_download_model_result common_download_model(const common_params_model & model, - const common_download_opts & opts, - bool download_mmproj, - bool download_mtp) { + const common_download_opts & opts) { common_download_model_result result; std::vector tasks; hf_plan hf; + bool download_mmproj = opts.download_mmproj; + bool download_mtp = opts.download_mtp; bool is_hf = !model.hf_repo.empty(); if (is_hf) { @@ -806,18 +814,22 @@ common_download_model_result common_download_model(const common_params_model & return result; } - std::vector> futures; + std::vector> futures; for (const auto & task : tasks) { futures.push_back(std::async(std::launch::async, [&task, &opts,
is_hf]() { - int status = common_download_file_single(task.url, task.path, opts, is_hf); - return is_http_status_ok(status); + return common_download_file_single(task.url, task.path, opts, is_hf); } )); } for (auto & f : futures) { - if (!f.get()) { + int status = f.get(); + if (status == -2 && opts.skip_download) { + throw common_skip_download_exception(); + } + bool is_ok = is_http_status_ok(status); + if (!is_ok) { return {}; } } diff --git a/common/download.h b/common/download.h index 4a169ef77..ebeedd605 100644 --- a/common/download.h +++ b/common/download.h @@ -52,6 +52,9 @@ struct common_download_opts { std::string bearer_token; common_header_list headers; bool offline = false; + bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid + bool download_mmproj = false; + bool download_mtp = false; common_download_callback * callback = nullptr; }; @@ -62,6 +65,11 @@ struct common_download_model_result { std::string mtp_path; }; +// throw if the file is missing or invalid (e.g. 
ETag check failed) +struct common_skip_download_exception : public std::runtime_error { + common_skip_download_exception() : std::runtime_error("skip download") {} +}; + // Download model from HuggingFace repo or URL // // input (via model struct): @@ -89,9 +97,7 @@ struct common_download_model_result { // returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure) common_download_model_result common_download_model( const common_params_model & model, - const common_download_opts & opts = {}, - bool download_mmproj = false, - bool download_mtp = false + const common_download_opts & opts = {} ); // returns list of cached models @@ -99,6 +105,7 @@ std::vector common_list_cached_models(); // download single file from url to local path // returns status code or -1 on error +// returns -2 if skip_download=true and the download was skipped because the file is missing or outdated (ETag mismatch) // skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash) int common_download_file_single(const std::string & url, const std::string & path, diff --git a/tools/server/README.md b/tools/server/README.md index 7870e3091..87600d9be 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1661,23 +1661,30 @@ Listing all models in cache. The model metadata will also include a field to ind { "data": [{ "id": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", - "in_cache": true, "path": "/Users/REDACTED/Library/Caches/llama.cpp/ggml-org_gemma-3-4b-it-GGUF_gemma-3-4b-it-Q4_K_M.gguf", "status": { "value": "loaded", "args": ["llama-server", "-ctx", "4096"] }, + "architecture": { + "input_modalities": [ + "text", + "image" + ], + "output_modalities": [ + "text" + ] + }, ... }] } ``` Note: -1. For a local GGUF (stored offline in a custom directory), the model object will have `"in_cache": false`. -2. Adding `?reload=1` to the query params will refresh the list of models. The behavior is as follow: +1.
Adding `?reload=1` to the query params will refresh the list of models. The behavior is as follows: - If a model is running but updated or removed from the source, it will be unloaded - If a model is not running, it will be added or updated according to the source -3. When the model is loaded, the info from `/v1/models` is forwarded to router's `/v1/models`. This includes metadata about the model and the runtime instance. +2. When the model is loaded, the info from `/v1/models` is forwarded to router's `/v1/models`. This includes metadata about the model and the runtime instance. The `status` object can be: diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 47b6c2a4e..49b0e423f 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -180,7 +180,8 @@ void server_model_meta::update_caps() { "LLAMA_ARG_HF_REPO", "LLAMA_ARG_HF_REPO_FILE", }); - params.offline = true; // avoid any unwanted network call during capability detection + params.offline = true; + // params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); if (params.mmproj.path.empty()) { multimodal = { false, false }; @@ -371,18 +372,19 @@ void server_models::load_models() { // FIRST LOAD: add all models, then unlock for autoloading for (const auto & [name, preset] : final_presets) { server_model_meta meta{ - /* preset */ preset, - /* name */ name, - /* aliases */ {}, - /* tags */ {}, - /* port */ 0, - /* status */ SERVER_MODEL_STATUS_UNLOADED, - /* last_used */ 0, - /* args */ std::vector(), - /* loaded_info */ {}, - /* exit_code */ 0, - /* stop_timeout */ DEFAULT_STOP_TIMEOUT, - /* multimodal */ mtmd_caps{false, false}, + /* preset */ preset, + /* name */ name, + /* aliases */ {}, + /* tags */ {}, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* args */ std::vector(), + /* loaded_info */ {}, + /*
exit_code */ 0, + /* stop_timeout */ DEFAULT_STOP_TIMEOUT, + /* multimodal */ mtmd_caps{false, false}, + /* need_download */ false, }; add_model(std::move(meta)); } @@ -524,18 +526,19 @@ void server_models::load_models() { for (const auto & [name, preset] : final_presets) { if (mapping.find(name) == mapping.end()) { server_model_meta meta{ - /* preset */ preset, - /* name */ name, - /* aliases */ {}, - /* tags */ {}, - /* port */ 0, - /* status */ SERVER_MODEL_STATUS_UNLOADED, - /* last_used */ 0, - /* args */ std::vector(), - /* loaded_info */ {}, - /* exit_code */ 0, - /* stop_timeout */ DEFAULT_STOP_TIMEOUT, - /* multimodal */ mtmd_caps{false, false}, + /* preset */ preset, + /* name */ name, + /* aliases */ {}, + /* tags */ {}, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* args */ std::vector(), + /* loaded_info */ {}, + /* exit_code */ 0, + /* stop_timeout */ DEFAULT_STOP_TIMEOUT, + /* multimodal */ mtmd_caps{false, false}, + /* need_download */ false, }; add_model(std::move(meta)); newly_added.push_back(name); @@ -1263,14 +1266,15 @@ void server_models_routes::init_routes() { }; json model_info = json { - {"id", meta.name}, - {"aliases", meta.aliases}, - {"tags", meta.tags}, - {"object", "model"}, // for OAI-compat - {"owned_by", "llamacpp"}, // for OAI-compat - {"created", t}, // for OAI-compat - {"status", status}, - {"architecture", architecture}, + {"id", meta.name}, + {"aliases", meta.aliases}, + {"tags", meta.tags}, + {"object", "model"}, // for OAI-compat + {"owned_by", "llamacpp"}, // for OAI-compat + {"created", t}, // for OAI-compat + {"status", status}, + {"architecture", architecture}, + {"need_download", meta.need_download}, // TODO: add other fields, may require reading GGUF metadata }; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index e96d76c91..2198589a7 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -67,6 +67,7 @@ struct server_model_meta { 
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown mtmd_caps multimodal; // multimodal capabilities + bool need_download = false; // whether the model needs to be downloaded before loading bool is_ready() const { return status == SERVER_MODEL_STATUS_LOADED;