diff --git a/common/arg.cpp b/common/arg.cpp index 5debdea7e..474072692 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3,6 +3,7 @@ #include "chat.h" #include "common.h" #include "download.h" +#include "hf-cache.h" #include "json-schema-to-grammar.h" #include "log.h" #include "sampling.h" @@ -329,60 +330,48 @@ struct handle_model_result { common_params_model mmproj; }; -static handle_model_result common_params_handle_model( - struct common_params_model & model, - const std::string & bearer_token, - bool offline) { +static handle_model_result common_params_handle_model(struct common_params_model & model, + const std::string & bearer_token, + bool offline) { handle_model_result result; - // handle pre-fill default model path and url based on hf_repo and hf_file - { - if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths - model.path = common_docker_resolve_model(model.docker_repo); - model.name = model.docker_repo; // set name for consistency - } else if (!model.hf_repo.empty()) { - // short-hand to avoid specifying --hf-file -> default it to --model - if (model.hf_file.empty()) { - if (model.path.empty()) { - auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline); - if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { - exit(1); // error message already printed - } - model.name = model.hf_repo; // repo name with tag - model.hf_repo = auto_detected.repo; // repo name without tag - model.hf_file = auto_detected.ggufFile; - if (!auto_detected.mmprojFile.empty()) { - result.found_mmproj = true; - result.mmproj.hf_repo = model.hf_repo; - result.mmproj.hf_file = auto_detected.mmprojFile; - } - } else { - model.hf_file = model.path; - } - } - - std::string model_endpoint = get_model_endpoint(); - model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; - // make sure model path is present (for caching purposes) - if (model.path.empty()) { - // this is to avoid different repo having same file name, or same file name in different subdirs - std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file); - model.path = fs_get_cache_file(filename); - } - - } else if (!model.url.empty()) { - if (model.path.empty()) { - auto f = string_split(model.url, '#').front(); - f = string_split(f, '?').front(); - model.path = fs_get_cache_file(string_split(f, '/').back()); - } + if (!model.docker_repo.empty()) { + model.path = common_docker_resolve_model(model.docker_repo); + model.name = model.docker_repo; + } else if (!model.hf_repo.empty()) { + // If -m was used with -hf, treat the model "path" as the hf_file to download + if (model.hf_file.empty() && !model.path.empty()) { + model.hf_file = model.path; + model.path = ""; } - } + common_download_model_opts opts; + opts.download_mmproj = true; + opts.offline = offline; + auto download_result = common_download_model(model, bearer_token, opts); - // then, download it if needed - if (!model.url.empty()) { - bool ok = common_download_model(model, bearer_token, offline); - if (!ok) { + if (download_result.model_path.empty()) { + LOG_ERR("error: failed to download model from Hugging Face\n"); + exit(1); + } + + model.name = model.hf_repo; + model.path = download_result.model_path; + + if (!download_result.mmproj_path.empty()) { + result.found_mmproj = true; + result.mmproj.path = download_result.mmproj_path; + } + } else if (!model.url.empty()) { + if (model.path.empty()) { + auto f = string_split(model.url, '#').front(); + f = string_split(f, '?').front(); 
+ model.path = fs_get_cache_file(string_split(f, '/').back()); + } + + common_download_model_opts opts; + opts.offline = offline; + auto download_result = common_download_model(model, bearer_token, opts); + if (download_result.model_path.empty()) { LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); exit(1); } @@ -542,6 +531,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // parse the first time to get -hf option (used for remote preset) parse_cli_args(); + // TODO: Remove later + try { + hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline); + } catch (const std::exception & e) { + LOG_WRN("HF cache migration failed: %s\n", e.what()); + } + // maybe handle remote preset if (!params.model.hf_repo.empty()) { std::string cli_hf_repo = params.model.hf_repo; @@ -1064,12 +1060,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-cl", "--cache-list"}, "show list of models in cache", [](common_params &) { - printf("model cache directory: %s\n", fs_get_cache_directory().c_str()); auto models = common_list_cached_models(); printf("number of models in cache: %zu\n", models.size()); for (size_t i = 0; i < models.size(); i++) { - auto & model = models[i]; - printf("%4d. %s\n", (int) i + 1, model.to_string().c_str()); + printf("%4zu. %s\n", i + 1, models[i].to_string().c_str()); } exit(0); } diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index aa03aea5a..bf44091d6 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -112,8 +112,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons } else { parser = content.build_parser(ctx); } - parser = wrap_for_generation_prompt(p, parser, inputs, reasoning.start); - return parser; + return p.prefix(inputs.generation_prompt, reasoning.start) + parser; }); } diff --git a/common/chat-auto-parser-helpers.cpp b/common/chat-auto-parser-helpers.cpp index 3a7a5c13a..2499464cd 100644 --- a/common/chat-auto-parser-helpers.cpp +++ b/common/chat-auto-parser-helpers.cpp @@ -308,22 +308,6 @@ std::vector prune_whitespace_segments(const std::vector & segm return result; } -common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p, - const common_peg_parser & prs, - const autoparser::generation_params & inputs, - const std::string & reasoning_start) { - auto parser = prs; - if (!inputs.generation_prompt.empty()) { - size_t end_pos = inputs.generation_prompt.size(); - if (!reasoning_start.empty() && inputs.generation_prompt.find(reasoning_start) != std::string::npos) { - end_pos = inputs.generation_prompt.find(reasoning_start); - } - std::string cut_genprompt = inputs.generation_prompt.substr(0, end_pos); - parser = p.literal(cut_genprompt) + parser; - } - return parser; -} - namespace autoparser { std::string apply_template(const common_chat_template & tmpl, const template_params & params) { diff --git a/common/chat-auto-parser-helpers.h b/common/chat-auto-parser-helpers.h index e13581e58..7cd031c4d 100644 --- a/common/chat-auto-parser-helpers.h +++ b/common/chat-auto-parser-helpers.h @@ -58,11 +58,6 @@ std::vector segmentize_markers(const std::string & text); // (MARKER, ""), (MARKER, "") ] std::vector prune_whitespace_segments(const std::vector & segments); -// Wrap parser with generation prompt parser -common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p, - const common_peg_parser & prs, - const 
autoparser::generation_params & inputs, - const std::string & reasoning_start = {}); namespace autoparser { // Apply a template with the given parameters, returning the rendered string (empty on failure) diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 5f7d422b4..07b487e15 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -802,6 +802,16 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys( return tool_choices; } +common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const std::string & delimiter) { + if (s.empty()) { + return eps(); + } + if (delimiter.empty()) { + return literal(s); + } + return literal(s.substr(0, s.rfind(delimiter))); +} + common_peg_parser common_chat_peg_builder::standard_json_tools( const std::string & section_start, const std::string & section_end, diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h index a497508d2..62402923c 100644 --- a/common/chat-peg-parser.h +++ b/common/chat-peg-parser.h @@ -82,6 +82,10 @@ class common_chat_peg_builder : public common_peg_parser_builder { common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); } common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); } + + // Return a parser that parses the prefix of a string, up to a given delimiter. + common_peg_parser prefix(const std::string & s, const std::string & delimiter = {}); + // Legacy-compatible helper for building standard JSON tool calls // Used by tests and manual parsers // name_key/args_key: JSON key names for function name and arguments diff --git a/common/chat.cpp b/common/chat.cpp index cab5b4c44..078e0cee2 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -887,14 +887,14 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ }; auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto generation_prompt = p.prefix(inputs.generation_prompt, "[THINK]"); auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); // Response format parser if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { // Ministral wants to emit json surrounded by code fences - return wrap_for_generation_prompt(p, reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```", - inputs, "[THINK]"); + return generation_prompt + (reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```"); } // Tool call parser @@ -914,13 +914,12 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ auto max_calls = inputs.parallel_tool_calls ? 
-1 : 1; auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); - return wrap_for_generation_prompt(p, reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls, - inputs, "[THINK]"); + return generation_prompt + (reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls); } // Content only parser include_grammar = false; - return wrap_for_generation_prompt(p, reasoning << p.content(p.rest()), inputs, "[THINK]"); + return generation_prompt + (reasoning << p.content(p.rest())); }); data.parser = parser.save(); @@ -1006,8 +1005,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp p.literal("<|channel|>final") + constraint + p.literal("<|message|>") + p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema))); - return wrap_for_generation_prompt(p, response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format), - inputs, "<|channel|>"); + return p.zero_or_more(start + analysis) + start + response_format; } if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { @@ -1036,15 +1034,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp auto tool_call = p.trigger_rule("tool-call", tool_choice); if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) { - return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call); + return p.zero_or_more(start + any) + start + tool_call; } - return wrap_for_generation_prompt(p, tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)), - inputs, "<|channel|>"); + return p.zero_or_more(start + any) + start + (tool_call | final_msg); } - return wrap_for_generation_prompt(p, final_msg | (any + p.zero_or_more(start + any) + start + final_msg), - inputs, "<|channel|>"); + return p.zero_or_more(start + any) + start + final_msg; }); data.parser = parser.save(); @@ -1095,11 +1091,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // When no tools, content goes until end auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>")); auto content_until_end = p.literal("all\n") + p.content(p.rest()); + auto generation_prompt = p.literal(inputs.generation_prompt); // If no tools or tool_choice is NONE, just parse content if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { // When no tools, just match the prefix and capture everything after - return wrap_for_generation_prompt(p, content_until_end + p.end(), inputs); + return generation_prompt + content_until_end + p.end(); } // Build tool call parsers for each available function @@ -1135,7 +1132,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ auto content_and_tool = content_until_tool + tool_choice; ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end(); } - return wrap_for_generation_prompt(p, ret, inputs); + return generation_prompt + ret; }); data.parser = parser.save(); @@ -1216,12 +1213,12 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp auto reasoning = extract_reasoning ? 
p.optional(THINK_START + p.reasoning( p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) + p.optional(p.literal(THINK_END))) : p.eps(); + auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); // Content only parser (no tools) if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { - return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, - inputs, THINK_START); + return generation_prompt + reasoning + p.content(p.rest()) + end; } // Build tool call parsers for each available function @@ -1257,8 +1254,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN })); - return wrap_for_generation_prompt(p, reasoning + content_before_tools + tool_calls + end, - inputs, THINK_START); + return generation_prompt + reasoning + content_before_tools + tool_calls + end; }); data.parser = parser.save(); @@ -1316,6 +1312,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat data.thinking_end_tag = THINK_END; auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); auto end = p.end(); auto reasoning = p.eps(); @@ -1324,8 +1321,7 @@ } if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { - return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, inputs, - THINK_START); + return generation_prompt + reasoning + p.content(p.rest()) + end; } auto tool_calls = p.rule("tool-calls", @@ -1337,8 +1333,7 @@ auto content = p.content(p.until(TOOL_CALL_START)); - return wrap_for_generation_prompt(p, reasoning + content + tool_calls + end, inputs, - THINK_START); + return generation_prompt + reasoning + content + tool_calls + end; }); data.parser = parser.save(); @@ -1411,7 +1406,7 @@ static common_chat_params common_chat_params_init_gigachat_v3( ret = p.content(p.rest()); } - return wrap_for_generation_prompt(p, ret, inputs); + return p.literal(inputs.generation_prompt) + ret; }); data.parser = parser.save(); @@ -1636,7 +1631,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.generation_prompt = params.generation_prompt; auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) { - return wrap_for_generation_prompt(p, p.content(p.rest()), params); + return p.prefix(params.generation_prompt) + p.content(p.rest()); }); data.parser = parser.save(); return data;
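The `p.prefix()` calls above all rely on the same string rule, implemented in common/chat-peg-parser.cpp earlier in this patch: keep the generation prompt only up to the last occurrence of the delimiter, so the reasoning sub-parser still gets to consume its own start tag. A minimal standalone sketch of that rule (the prompt strings are hypothetical):

```cpp
#include <cassert>
#include <string>

// Sketch of the string logic behind common_chat_peg_builder::prefix():
// trim the generation prompt at the last occurrence of the delimiter
// (e.g. a reasoning-start tag); if the delimiter is absent, rfind()
// returns npos and substr(0, npos) keeps the whole string.
static std::string prompt_prefix(const std::string & s, const std::string & delimiter = {}) {
    if (s.empty() || delimiter.empty()) {
        return s; // corresponds to eps() / literal(s) in the real builder
    }
    return s.substr(0, s.rfind(delimiter));
}

int main() {
    assert(prompt_prefix("<|assistant|>[THINK]", "[THINK]") == "<|assistant|>");
    assert(prompt_prefix("<|assistant|>", "[THINK]") == "<|assistant|>"); // delimiter absent
    assert(prompt_prefix("", "[THINK]").empty());
}
```
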
repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)"); - return std::regex_match(repo, repo_regex); -} - -static std::string get_manifest_path(const std::string & repo, const std::string & tag) { - // we use "=" to avoid clashing with other component, while still being allowed on windows - std::string fname = "manifest=" + repo + "=" + tag + ".json"; - if (!validate_repo_name(repo)) { - throw std::runtime_error("error: repo name must be in the format 'owner/repo'"); - } - string_replace_all(fname, "/", "="); - return fs_get_cache_file(fname); -} - -static std::string read_file(const std::string & fname) { - std::ifstream file(fname); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); - } - std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); - file.close(); - return content; -} - static void write_file(const std::string & fname, const std::string & content) { const std::string fname_tmp = fname + ".tmp"; std::ofstream file(fname_tmp); @@ -134,7 +108,7 @@ static bool is_http_status_ok(int status) { std::pair common_download_split_repo_tag(const std::string & hf_repo_with_tag) { auto parts = string_split(hf_repo_with_tag, ':'); - std::string tag = parts.size() > 1 ? parts.back() : "latest"; + std::string tag = parts.size() > 1 ? parts.back() : ""; std::string hf_repo = parts[0]; if (string_split(hf_repo, '/').size() != 2) { throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); @@ -294,7 +268,8 @@ static bool common_pull_file(httplib::Client & cli, static int common_download_file_single_online(const std::string & url, const std::string & path, const std::string & bearer_token, - const common_header_list & custom_headers) { + const common_header_list & custom_headers, + bool skip_etag = false) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; @@ -314,6 +289,11 @@ static int common_download_file_single_online(const std::string & url, const bool file_exists = std::filesystem::exists(path); + if (file_exists && skip_etag) { + LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); + return 304; // 304 Not Modified - fake cached response + } + std::string last_etag; if (file_exists) { last_etag = read_etag(path); @@ -365,6 +345,12 @@ static int common_download_file_single_online(const std::string & url, } } + { // silent + std::error_code ec; + std::filesystem::path p(path); + std::filesystem::create_directories(p.parent_path(), ec); + } + const std::string path_temporary = path + ".downloadInProgress"; int delay = retry_delay_seconds; @@ -395,7 +381,7 @@ static int common_download_file_single_online(const std::string & url, LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return -1; } - if (!etag.empty()) { + if (!etag.empty() && !skip_etag) { write_etag(path, etag); } return head->status; @@ -444,9 +430,10 @@ int common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline, - const common_header_list & headers) { + const common_header_list & headers, + bool skip_etag) { if (!offline) { - return common_download_file_single_online(url, path, bearer_token, headers); + return common_download_file_single_online(url, path, bearer_token, headers, skip_etag); } if (!std::filesystem::exists(path)) { @@ -458,193 +445,293 @@ int common_download_file_single(const std::string & url, return 304; // Not Modified - fake cached response } -// 
download multiple files from remote URLs to local paths -// the input is a vector of pairs <url, path> -static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, - const std::string & bearer_token, - bool offline, - const common_header_list & headers) { - // Prepare download in parallel - std::vector<std::future<bool>> futures_download; - futures_download.reserve(urls.size()); +struct gguf_split_info { + std::string prefix; // tag included + std::string tag; + int index; + int count; +}; - for (auto const & item : urls) { - futures_download.push_back( - std::async( - std::launch::async, - [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool { - const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers); - return is_http_status_ok(http_status); - }, - item - ) - ); +static gguf_split_info get_gguf_split_info(const std::string & path) { + static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase); + static const std::regex re_tag("[-.]([A-Z0-9_]+)$", std::regex::icase); + std::smatch m; + + std::string prefix = path; + string_remove_suffix(prefix, ".gguf"); + + int index = 1; + int count = 1; + + if (std::regex_match(prefix, m, re_split)) { + index = std::stoi(m[2].str()); + count = std::stoi(m[3].str()); + prefix = m[1].str(); } - // Wait for all downloads to complete - for (auto & f : futures_download) { - if (!f.get()) { - return false; + std::string tag; + if (std::regex_search(prefix, m, re_tag)) { + tag = m[1].str(); + for (char & c : tag) { + c = std::toupper((unsigned char)c); } } - return true; + return {std::move(prefix), std::move(tag), index, count}; }
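For intuition, here is what `get_gguf_split_info()` extracts from a split GGUF name. This standalone sketch mirrors the two regexes above; the filename is hypothetical:

```cpp
#include <cassert>
#include <regex>
#include <string>

// Mirrors get_gguf_split_info() in this patch: "<prefix>-00002-of-00005.gguf"
// yields the split index/count, and a trailing "[-.]TAG" on the remaining
// prefix yields the upper-cased quant tag (the prefix keeps the tag).
struct split_info { std::string prefix, tag; int index = 1, count = 1; };

static split_info parse(std::string name) {
    static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase);
    static const std::regex re_tag("[-.]([A-Z0-9_]+)$", std::regex::icase);
    split_info r;
    if (name.size() > 5 && name.compare(name.size() - 5, 5, ".gguf") == 0) {
        name.resize(name.size() - 5); // stand-in for string_remove_suffix()
    }
    std::smatch m;
    if (std::regex_match(name, m, re_split)) {
        r.index = std::stoi(m[2].str());
        r.count = std::stoi(m[3].str());
        name = m[1].str();
    }
    if (std::regex_search(name, m, re_tag)) {
        r.tag = m[1].str();
        for (char & c : r.tag) c = std::toupper((unsigned char) c);
    }
    r.prefix = name;
    return r;
}

int main() {
    auto s = parse("gemma-3-4b-it-Q4_K_M-00002-of-00005.gguf"); // hypothetical name
    assert(s.prefix == "gemma-3-4b-it-Q4_K_M" && s.tag == "Q4_K_M");
    assert(s.index == 2 && s.count == 5);
}
```
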
-bool common_download_model(const common_params_model & model, - const std::string & bearer_token, - bool offline, - const common_header_list & headers) { - // Basic validation of the model.url - if (model.url.empty()) { - LOG_ERR("%s: invalid model url\n", __func__); - return false; +// Q4_0 -> 4, F16 -> 16, NVFP4 -> 4, Q8_K_M -> 8, etc +static int extract_quant_bits(const std::string & filename) { + auto split = get_gguf_split_info(filename); + + auto pos = split.tag.find_first_of("0123456789"); + if (pos == std::string::npos) { + return 0; } - const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers); - if (!is_http_status_ok(http_status)) { - return false; - } - - // check for additional GGUFs split to download - int n_split = 0; - { - struct gguf_init_params gguf_params = { - /*.no_alloc = */ true, - /*.ctx = */ NULL, - }; - auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params); - if (!ctx_gguf) { - LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str()); - return false; - } - - auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT); - if (key_n_split >= 0) { - n_split = gguf_get_val_u16(ctx_gguf, key_n_split); - } - - gguf_free(ctx_gguf); - } - - if (n_split > 1) { - char split_prefix[PATH_MAX] = {0}; - char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0}; - - // Verify the first split file format - // and extract split URL and PATH prefixes - { - if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split); - return false; - } - - if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split); - return false; - } - } - - std::vector<std::pair<std::string, std::string>> urls; - for (int idx = 1; idx < n_split; idx++) { - char split_path[PATH_MAX] = {0}; - llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split); - - char split_url[LLAMA_MAX_URL_LENGTH] = {0}; - llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split); - - if (std::string(split_path) == model.path) { - continue; // skip the already downloaded file - } - - urls.push_back({split_url, split_path}); - } - - // Download in parallel - common_download_file_multiple(urls, bearer_token, offline, headers); - } - - return true; + return std::stoi(split.tag.substr(pos)); } -common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, - const std::string & bearer_token, - bool offline, - const common_header_list & custom_headers) { - // the returned hf_repo is without tag - auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag); +static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files, + const hf_cache::hf_file & file) { + auto split = get_gguf_split_info(file.path); - std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; - - // headers - common_header_list headers = custom_headers; - headers.push_back({"Accept", "application/json"}); - if (!bearer_token.empty()) { - headers.push_back({"Authorization", "Bearer " + bearer_token}); + if (split.count <= 1) { + return {file}; } - // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response - // User-Agent header is already set in common_remote_get_content, no need to set it here + hf_cache::hf_files result; - // make the request - common_remote_params params; - params.headers = headers; - long res_code = 0; - std::string res_str; - bool use_cache = false; - std::string cached_response_path = get_manifest_path(hf_repo, tag); - if (!offline) { - try { - auto res = common_remote_get_content(url, params); - res_code = res.first; - res_str = std::string(res.second.data(), res.second.size()); - } catch (const std::exception & e) { - LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what()); + for (const auto & f : files) { + auto split_f = get_gguf_split_info(f.path); + if (split_f.count == split.count && split_f.prefix == split.prefix) { + result.push_back(f); } } - if (res_code == 0) { - if (std::filesystem::exists(cached_response_path)) { - LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str()); - res_str = read_file(cached_response_path); - res_code = 200; - use_cache = true; - } else { - throw std::runtime_error( - offline ? 
"error: failed to get manifest (offline mode)" - : "error: failed to get manifest (check your internet connection)"); + return result; +} + +static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, + const std::string & model) { + hf_cache::hf_file best; + size_t best_depth = 0; + int best_diff = 0; + bool found = false; + + auto model_bits = extract_quant_bits(model); + auto model_parts = string_split(model, '/'); + auto model_dir = model_parts.end() - 1; + + for (const auto & f : files) { + if (!string_ends_with(f.path, ".gguf") || + f.path.find("mmproj") == std::string::npos) { + continue; + } + + auto mmproj_parts = string_split(f.path, '/'); + auto mmproj_dir = mmproj_parts.end() - 1; + + auto [_, dir] = std::mismatch(model_parts.begin(), model_dir, + mmproj_parts.begin(), mmproj_dir); + if (dir != mmproj_dir) { + continue; + } + + size_t depth = dir - mmproj_parts.begin(); + auto bits = extract_quant_bits(f.path); + auto diff = std::abs(bits - model_bits); + + if (!found || depth > best_depth || (depth == best_depth && diff < best_diff)) { + best = f; + best_depth = depth; + best_diff = diff; + found = true; } } - std::string ggufFile; - std::string mmprojFile; + return best; +} - if (res_code == 200 || res_code == 304) { - try { - auto j = json::parse(res_str); +static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files, + const std::string & tag) { + std::vector tags; - if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) { - ggufFile = j["ggufFile"]["rfilename"].get(); - } - if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) { - mmprojFile = j["mmprojFile"]["rfilename"].get(); - } - } catch (const std::exception & e) { - throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what()); - } - if (!use_cache) { - // if not using cached response, update the cache file - write_file(cached_response_path, res_str); - } - } else if (res_code == 401) { - throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); + if (!tag.empty()) { + tags.push_back(tag); } else { - throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str())); + tags = {"Q4_K_M", "Q4_0"}; } - // check response - if (ggufFile.empty()) { - throw std::runtime_error("error: model does not have ggufFile"); + for (const auto & t : tags) { + std::regex pattern(t + "[.-]", std::regex::icase); + for (const auto & f : files) { + if (string_ends_with(f.path, ".gguf") && + f.path.find("mmproj") == std::string::npos && + std::regex_search(f.path, pattern)) { + return f; + } + } } - return { hf_repo, ggufFile, mmprojFile }; + for (const auto & f : files) { + if (string_ends_with(f.path, ".gguf") && + f.path.find("mmproj") == std::string::npos) { + return f; + } + } + + return {}; +} + +static void list_available_gguf_files(const hf_cache::hf_files & files) { + LOG_INF("Available GGUF files:\n"); + for (const auto & f : files) { + if (string_ends_with(f.path, ".gguf")) { + LOG_INF(" - %s\n", f.path.c_str()); + } + } +} + +struct hf_plan { + hf_cache::hf_files model_files; + hf_cache::hf_file mmproj; +}; + +static hf_plan get_hf_plan(const common_params_model & model, + const std::string & token, + const common_download_model_opts & opts) { + hf_plan plan; + hf_cache::hf_files all; + + auto [repo, tag] = common_download_split_repo_tag(model.hf_repo); + + if (!opts.offline) { + all = 
+ +static void list_available_gguf_files(const hf_cache::hf_files & files) { + LOG_INF("Available GGUF files:\n"); + for (const auto & f : files) { + if (string_ends_with(f.path, ".gguf")) { + LOG_INF(" - %s\n", f.path.c_str()); + } + } +} + +struct hf_plan { + hf_cache::hf_files model_files; + hf_cache::hf_file mmproj; +}; + +static hf_plan get_hf_plan(const common_params_model & model, + const std::string & token, + const common_download_model_opts & opts) { + hf_plan plan; + hf_cache::hf_files all; + + auto [repo, tag] = common_download_split_repo_tag(model.hf_repo); + + if (!opts.offline) { + all = hf_cache::get_repo_files(repo, token); + } + if (all.empty()) { + all = hf_cache::get_cached_files(repo); + } + if (all.empty()) { + return plan; + } + + hf_cache::hf_file primary; + + if (!model.hf_file.empty()) { + for (const auto & f : all) { + if (f.path == model.hf_file) { + primary = f; + break; + } + } + if (primary.path.empty()) { + LOG_ERR("%s: file '%s' not found in repository\n", __func__, model.hf_file.c_str()); + list_available_gguf_files(all); + return plan; + } + } else { + primary = find_best_model(all, tag); + if (primary.path.empty()) { + LOG_ERR("%s: no GGUF files found in repository %s\n", __func__, repo.c_str()); + list_available_gguf_files(all); + return plan; + } + } + + plan.model_files = get_split_files(all, primary); + + if (opts.download_mmproj) { + plan.mmproj = find_best_mmproj(all, primary.path); + } + + return plan; +} + +struct download_task { + std::string url; + std::string path; +}; + +static std::vector<download_task> get_url_tasks(const common_params_model & model) { + auto split = get_gguf_split_info(model.url); + + if (split.count <= 1) { + return {{model.url, model.path}}; + } + + auto filename = split.prefix; + if (auto pos = split.prefix.rfind('/'); pos != std::string::npos) { + filename = split.prefix.substr(pos + 1); + } + + auto parent_path = std::filesystem::path(model.path).parent_path(); + auto prefix_path = (parent_path / filename).string(); + + std::vector<download_task> tasks; + for (int i = 1; i <= split.count; i++) { + auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count); + tasks.push_back({split.prefix + suffix, prefix_path + suffix}); + } + return tasks; +} + +common_download_model_result common_download_model(const common_params_model & model, + const std::string & bearer_token, + const common_download_model_opts & opts, + const common_header_list & headers) { + common_download_model_result result; + std::vector<download_task> tasks; + hf_plan hf; + + bool is_hf = !model.hf_repo.empty(); + + if (is_hf) { + hf = get_hf_plan(model, bearer_token, opts); + for (const auto & f : hf.model_files) { + tasks.push_back({f.url, f.local_path}); + } + if (!hf.mmproj.path.empty()) { + tasks.push_back({hf.mmproj.url, hf.mmproj.local_path}); + } + } else if (!model.url.empty()) { + tasks = get_url_tasks(model); + } else { + result.model_path = model.path; + return result; + } + + if (tasks.empty()) { + return result; + } + + std::vector<std::future<bool>> futures; + for (const auto & task : tasks) { + futures.push_back(std::async(std::launch::async, + [&task, &bearer_token, offline = opts.offline, &headers, is_hf]() { + int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf); + return is_http_status_ok(status); + } + )); + } + + for (auto & f : futures) { + if (!f.get()) { + return {}; + } + } + + if (is_hf) { + for (const auto & f : hf.model_files) { + hf_cache::finalize_file(f); + } + result.model_path = hf.model_files[0].final_path; + + if (!hf.mmproj.path.empty()) { + result.mmproj_path = hf_cache::finalize_file(hf.mmproj); + } + } else { + result.model_path = model.path; + } + + return result; } // @@ -793,28 +880,21 @@ int common_download_file_single(const std::string &, #endif // defined(LLAMA_USE_HTTPLIB) std::vector<common_cached_model_info> common_list_cached_models() { - std::vector<common_cached_model_info> models; - const std::string cache_dir = fs_get_cache_directory(); - const std::vector files = fs_list(cache_dir, false); - for (const auto & file : files) { - if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) { - common_cached_model_info model_info; - 
model_info.manifest_path = file.path; - std::string fname = file.name; - string_replace_all(fname, ".json", ""); // remove extension - auto parts = string_split(fname, '='); - if (parts.size() == 4) { - // expect format: manifest=<user>=<model>=<tag> - model_info.user = parts[1]; - model_info.model = parts[2]; - model_info.tag = parts[3]; - } else { - // invalid format - continue; - } - model_info.size = 0; // TODO: get GGUF size, not manifest size - models.push_back(model_info); + std::unordered_set<std::string> seen; + std::vector<common_cached_model_info> result; + + auto files = hf_cache::get_cached_files(); + + for (const auto & f : files) { + auto split = get_gguf_split_info(f.path); + if (split.index != 1 || split.tag.empty() || + split.prefix.find("mmproj") != std::string::npos) { + continue; + } + if (seen.insert(f.repo_id + ":" + split.tag).second) { + result.push_back({f.repo_id, split.tag}); } } - return models; + + return result; }
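A hypothetical call site for the reworked download API (types and options as declared in common/download.h below; the repo name is only an example):

```cpp
#include "common.h"
#include "download.h"

// Hedged usage sketch: resolve and download a model plus its mmproj via the
// new common_download_model(); everything except the example repo name comes
// from the declarations in this patch.
static std::string fetch_example_model() {
    common_params_model model;
    model.hf_repo = "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"; // example repo:tag

    common_download_model_opts opts;
    opts.download_mmproj = true;  // also look for a matching mmproj
    opts.offline         = false; // allow network access

    auto res = common_download_model(model, /*bearer_token=*/"", opts);
    // res.model_path is empty on failure; res.mmproj_path is set only when
    // an mmproj file was found near the model in the repo.
    return res.model_path;
}
```
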
diff --git a/common/download.h b/common/download.h index 1c1d8e6db..0a933521f 100644 --- a/common/download.h +++ b/common/download.h @@ -17,54 +17,60 @@ struct common_remote_params { // get remote file content, returns <http_code, raw response body> std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params); -// split HF repo with tag into <repo, tag> -// for example: "user/model:tag" -> <"user/model", "tag"> -// if tag is not present, default to "latest" -// example: "user/model" -> <"user/model", "latest"> +// split HF repo with tag into <repo, tag>, for example: +// - "ggml-org/models:F16" -> <"ggml-org/models", "F16"> +// tag is optional and can be empty std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag); +// Result of common_list_cached_models struct common_cached_model_info { - std::string manifest_path; - std::string user; - std::string model; + std::string repo; std::string tag; - size_t size = 0; // GGUF size in bytes - // return string representation like "user/model:tag" - // if tag is "latest", it will be omitted std::string to_string() const { - return user + "/" + model + (tag == "latest" ? "" : ":" + tag); + return repo + ":" + tag; } }; -struct common_hf_file_res { - std::string repo; // repo name with ":tag" removed - std::string ggufFile; - std::string mmprojFile; +// Options for common_download_model +struct common_download_model_opts { + bool download_mmproj = false; + bool offline = false; }; -/** - * Allow getting the HF file from the HF repo with tag (like ollama), for example: - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 - * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s - * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) - * - * Return pair of <repo, file> (with "repo" already having tag removed) - * - * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. - */ -common_hf_file_res common_get_hf_file( - const std::string & hf_repo_with_tag, - const std::string & bearer_token, - bool offline, - const common_header_list & headers = {} -); +// Result of common_download_model +struct common_download_model_result { + std::string model_path; + std::string mmproj_path; +}; -// returns true if download succeeded -bool common_download_model( +// Download model from HuggingFace repo or URL +// +// input (via model struct): +// - model.hf_repo: HF repo with optional tag, see common_download_split_repo_tag +// - model.hf_file: specific file in the repo (requires hf_repo) +// - model.url: simple download (used if hf_repo is empty) +// - model.path: local file path +// +// tag matching (for HF repos without model.hf_file): +// - if tag is specified, searches for GGUF matching that quantization +// - if no tag, searches for Q4_K_M, then Q4_0, then first available GGUF +// +// split GGUF: multi-part files like "model-00001-of-00003.gguf" are automatically +// detected and all parts are downloaded +// +// caching: +// - HF repos: uses HuggingFace cache +// - URLs: uses ETag-based caching +// +// when opts.offline=true, no network requests are made +// when opts.download_mmproj=true, searches for an mmproj file in the model's +// directory or the nearest parent directory, preferring the closest +// quantization bit width +// +// returns result with model_path and mmproj_path (empty on failure) common_download_model_result common_download_model( const common_params_model & model, const std::string & bearer_token, - bool offline, + const common_download_model_opts & opts = {}, const common_header_list & headers = {} ); @@ -73,11 +79,13 @@ std::vector<common_cached_model_info> common_list_cached_models(); // download single file from url to local path // returns status code or -1 on error +// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash) int common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline, - const common_header_list & headers = {}); + const common_header_list & headers = {}, + bool skip_etag = false); // resolve and download model from Docker registry // return local path to downloaded model file diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp new file mode 100644 index 000000000..ce66f6467 --- /dev/null +++ b/common/hf-cache.cpp @@ -0,0 +1,644 @@ +#include "hf-cache.h" + +#include "common.h" +#include "log.h" +#include "http.h" + +#define JSON_ASSERT GGML_ASSERT +#include <nlohmann/json.hpp> + +#include +#include +#include +#include // migration only +#include +#include +#include + +namespace nl = nlohmann; + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#define HOME_DIR "USERPROFILE" +#include <windows.h> +#else +#define HOME_DIR "HOME" +#endif + +namespace hf_cache { + +namespace fs = std::filesystem; + +static fs::path get_cache_directory() { + static const fs::path cache = []() { + struct { + const char * var; + fs::path path; + } entries[] = { + {"HF_HUB_CACHE", fs::path()}, + {"HUGGINGFACE_HUB_CACHE", fs::path()}, + {"HF_HOME", fs::path("hub")}, + {"XDG_CACHE_HOME", fs::path("huggingface") / "hub"}, + {HOME_DIR, fs::path(".cache") / "huggingface" / "hub"} + }; + for (const auto & entry : entries) { + if (auto * p = std::getenv(entry.var); p && *p) { + fs::path base(p); + return entry.path.empty() ? 
base : base / entry.path; + } + } + throw std::runtime_error("Failed to determine HF cache directory"); + }(); + + return cache; +} + +static std::string folder_name_to_repo(const std::string & folder) { + constexpr std::string_view prefix = "models--"; + if (folder.rfind(prefix, 0)) { + return {}; + } + std::string result = folder.substr(prefix.length()); + string_replace_all(result, "--", "/"); + return result; +} + +static std::string repo_to_folder_name(const std::string & repo_id) { + constexpr std::string_view prefix = "models--"; + std::string result = std::string(prefix) + repo_id; + string_replace_all(result, "/", "--"); + return result; +} + +static fs::path get_repo_path(const std::string & repo_id) { + return get_cache_directory() / repo_to_folder_name(repo_id); +} + +static bool is_hex_char(const char c) { + return (c >= 'A' && c <= 'F') || + (c >= 'a' && c <= 'f') || + (c >= '0' && c <= '9'); +} + +static bool is_hex_string(const std::string & s, size_t expected_len) { + if (s.length() != expected_len) { + return false; + } + for (const char c : s) { + if (!is_hex_char(c)) { + return false; + } + } + return true; +} + +static bool is_alphanum(const char c) { + return (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9'); +} + +static bool is_special_char(char c) { + return c == '/' || c == '.' || c == '-'; +} + +// base chars [A-Za-z0-9_] are always valid +// special chars [/.-] must be surrounded by base chars +// exactly one '/' required +static bool is_valid_repo_id(const std::string & repo_id) { + if (repo_id.empty() || repo_id.length() > 256) { + return false; + } + int slash = 0; + bool special = true; + + for (const char c : repo_id) { + if (is_alphanum(c) || c == '_') { + special = false; + } else if (is_special_char(c)) { + if (special) { + return false; + } + slash += (c == '/'); + special = true; + } else { + return false; + } + } + return !special && slash == 1; +} + +static bool is_valid_hf_token(const std::string & token) { + if (token.length() < 37 || token.length() > 256 || + !string_starts_with(token, "hf_")) { + return false; + } + for (size_t i = 3; i < token.length(); ++i) { + if (!is_alphanum(token[i])) { + return false; + } + } + return true; +} + +static bool is_valid_commit(const std::string & hash) { + return is_hex_string(hash, 40); +} + +static bool is_valid_oid(const std::string & oid) { + return is_hex_string(oid, 40) || is_hex_string(oid, 64); +} + +static bool is_valid_subpath(const fs::path & path, const fs::path & subpath) { + if (subpath.is_absolute()) { + return false; // never do a / b with b absolute + } + auto b = fs::absolute(path).lexically_normal(); + auto t = (b / subpath).lexically_normal(); + auto [b_end, _] = std::mismatch(b.begin(), b.end(), t.begin(), t.end()); + + return b_end == b.end(); +} + +static void safe_write_file(const fs::path & path, const std::string & data) { + fs::path path_tmp = path.string() + ".tmp"; + + if (path.has_parent_path()) { + fs::create_directories(path.parent_path()); + } + + std::ofstream file(path_tmp); + file << data; + file.close(); + + std::error_code ec; + + if (!file.fail()) { + fs::rename(path_tmp, path, ec); + } + if (file.fail() || ec) { + fs::remove(path_tmp, ec); + throw std::runtime_error("failed to write file: " + path.string()); + } +} + +static nl::json api_get(const std::string & url, + const std::string & token) { + auto [cli, parts] = common_http_client(url); + + httplib::Headers headers = { + {"User-Agent", "llama-cpp/" + build_info}, + {"Accept", 
"application/json"} + }; + + if (is_valid_hf_token(token)) { + headers.emplace("Authorization", "Bearer " + token); + } else if (!token.empty()) { + LOG_WRN("%s: invalid token, authentication disabled\n", __func__); + } + + if (auto res = cli.Get(parts.path, headers)) { + auto body = res->body; + + if (res->status == 200) { + return nl::json::parse(res->body); + } + try { + body = nl::json::parse(res->body)["error"].get(); + } catch (...) { } + + throw std::runtime_error("GET failed (" + std::to_string(res->status) + "): " + body); + } else { + throw std::runtime_error("HTTPLIB failed: " + httplib::to_string(res.error())); + } +} + +static std::string get_repo_commit(const std::string & repo_id, + const std::string & token) { + try { + auto endpoint = get_model_endpoint(); + auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token); + + if (!json.is_object() || + !json.contains("branches") || !json["branches"].is_array()) { + LOG_WRN("%s: missing 'branches' for '%s'\n", __func__, repo_id.c_str()); + return {}; + } + + fs::path refs_path = get_repo_path(repo_id) / "refs"; + std::string name; + std::string commit; + + for (const auto & branch : json["branches"]) { + if (!branch.is_object() || + !branch.contains("name") || !branch["name"].is_string() || + !branch.contains("targetCommit") || !branch["targetCommit"].is_string()) { + continue; + } + std::string _name = branch["name"].get(); + std::string _commit = branch["targetCommit"].get(); + + if (!is_valid_subpath(refs_path, _name)) { + LOG_WRN("%s: skip invalid branch: %s\n", __func__, _name.c_str()); + continue; + } + if (!is_valid_commit(_commit)) { + LOG_WRN("%s: skip invalid commit: %s\n", __func__, _commit.c_str()); + continue; + } + + if (_name == "main") { + name = _name; + commit = _commit; + break; + } + + if (name.empty() || commit.empty()) { + name = _name; + commit = _commit; + } + } + + if (name.empty() || commit.empty()) { + LOG_WRN("%s: no valid branch for '%s'\n", __func__, repo_id.c_str()); + return {}; + } + + safe_write_file(refs_path / name, commit); + return commit; + + } catch (const nl::json::exception & e) { + LOG_ERR("%s: JSON error: %s\n", __func__, e.what()); + } catch (const std::exception & e) { + LOG_ERR("%s: error: %s\n", __func__, e.what()); + } + return {}; +} + +hf_files get_repo_files(const std::string & repo_id, + const std::string & token) { + if (!is_valid_repo_id(repo_id)) { + LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str()); + return {}; + } + + std::string commit = get_repo_commit(repo_id, token); + if (commit.empty()) { + LOG_WRN("%s: failed to resolve commit for %s\n", __func__, repo_id.c_str()); + return {}; + } + + fs::path blobs_path = get_repo_path(repo_id) / "blobs"; + fs::path commit_path = get_repo_path(repo_id) / "snapshots" / commit; + + hf_files files; + + try { + auto endpoint = get_model_endpoint(); + auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token); + + if (!json.is_array()) { + LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str()); + return {}; + } + + for (const auto & item : json) { + if (!item.is_object() || + !item.contains("type") || !item["type"].is_string() || item["type"] != "file" || + !item.contains("path") || !item["path"].is_string()) { + continue; + } + + hf_file file; + file.repo_id = repo_id; + file.path = item["path"].get(); + + if (!is_valid_subpath(commit_path, file.path)) { + LOG_WRN("%s: skip invalid path: %s\n", __func__, file.path.c_str()); + continue; 
+ +hf_files get_repo_files(const std::string & repo_id, + const std::string & token) { + if (!is_valid_repo_id(repo_id)) { + LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str()); + return {}; + } + + std::string commit = get_repo_commit(repo_id, token); + if (commit.empty()) { + LOG_WRN("%s: failed to resolve commit for %s\n", __func__, repo_id.c_str()); + return {}; + } + + fs::path blobs_path = get_repo_path(repo_id) / "blobs"; + fs::path commit_path = get_repo_path(repo_id) / "snapshots" / commit; + + hf_files files; + + try { + auto endpoint = get_model_endpoint(); + auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token); + + if (!json.is_array()) { + LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str()); + return {}; + } + + for (const auto & item : json) { + if (!item.is_object() || + !item.contains("type") || !item["type"].is_string() || item["type"] != "file" || + !item.contains("path") || !item["path"].is_string()) { + continue; + } + + hf_file file; + file.repo_id = repo_id; + file.path = item["path"].get<std::string>(); + + if (!is_valid_subpath(commit_path, file.path)) { + LOG_WRN("%s: skip invalid path: %s\n", __func__, file.path.c_str()); + continue; + } + + if (item.contains("lfs") && item["lfs"].is_object()) { + if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) { + file.oid = item["lfs"]["oid"].get<std::string>(); + } + } else if (item.contains("oid") && item["oid"].is_string()) { + file.oid = item["oid"].get<std::string>(); + } + + if (!file.oid.empty() && !is_valid_oid(file.oid)) { + LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str()); + continue; + } + + file.url = endpoint + repo_id + "/resolve/" + commit + "/" + file.path; + + fs::path final_path = commit_path / file.path; + file.final_path = final_path.string(); + + if (!file.oid.empty() && !fs::exists(final_path)) { + fs::path local_path = blobs_path / file.oid; + file.local_path = local_path.string(); + } else { + file.local_path = file.final_path; + } + + files.push_back(file); + } + } catch (const nl::json::exception & e) { + LOG_ERR("%s: JSON error: %s\n", __func__, e.what()); + } catch (const std::exception & e) { + LOG_ERR("%s: error: %s\n", __func__, e.what()); + } + return files; +} + +static std::string get_cached_ref(const fs::path & repo_path) { + fs::path refs_path = repo_path / "refs"; + if (!fs::is_directory(refs_path)) { + return {}; + } + std::string fallback; + + for (const auto & entry : fs::directory_iterator(refs_path)) { + if (!entry.is_regular_file()) { + continue; + } + std::ifstream f(entry.path()); + std::string commit; + if (!f || !std::getline(f, commit) || commit.empty()) { + continue; + } + if (!is_valid_commit(commit)) { + LOG_WRN("%s: skip invalid commit: %s\n", __func__, commit.c_str()); + continue; + } + if (entry.path().filename() == "main") { + return commit; + } + if (fallback.empty()) { + fallback = commit; + } + } + return fallback; +} + +hf_files get_cached_files(const std::string & repo_id) { + fs::path cache_dir = get_cache_directory(); + if (!fs::exists(cache_dir)) { + return {}; + } + + if (!repo_id.empty() && !is_valid_repo_id(repo_id)) { + LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str()); + return {}; + } + + hf_files files; + + for (const auto & repo : fs::directory_iterator(cache_dir)) { + if (!repo.is_directory()) { + continue; + } + fs::path snapshots_path = repo.path() / "snapshots"; + + if (!fs::exists(snapshots_path)) { + continue; + } + std::string _repo_id = folder_name_to_repo(repo.path().filename().string()); + + if (!is_valid_repo_id(_repo_id)) { + continue; + } + if (!repo_id.empty() && _repo_id != repo_id) { + continue; + } + std::string commit = get_cached_ref(repo.path()); + fs::path commit_path = snapshots_path / commit; + + if (commit.empty() || !fs::is_directory(commit_path)) { + continue; + } + for (const auto & entry : fs::recursive_directory_iterator(commit_path)) { + if (!entry.is_regular_file() && !entry.is_symlink()) { + continue; + } + fs::path path = entry.path().lexically_relative(commit_path); + + if (!path.empty()) { + hf_file file; + file.repo_id = _repo_id; + file.path = path.generic_string(); + file.local_path = entry.path().string(); + file.final_path = file.local_path; + files.push_back(std::move(file)); + } + } + } + + return files; +} + +std::string finalize_file(const hf_file & file) { + static std::atomic<bool> symlinks_disabled{false}; + + std::error_code ec; + fs::path local_path(file.local_path); + fs::path final_path(file.final_path); + + if (local_path == final_path || fs::exists(final_path, ec)) { + return file.final_path; + } + + if (!fs::exists(local_path, ec)) { + return file.final_path; + } + + fs::create_directories(final_path.parent_path(), ec); + + if 
(!symlinks_disabled) { + fs::path target = fs::relative(local_path, final_path.parent_path(), ec); + if (!ec) { + fs::create_symlink(target, final_path, ec); + } + if (!ec) { + return file.final_path; + } + } + + if (!symlinks_disabled.exchange(true)) { + LOG_WRN("%s: failed to create symlink: %s\n", __func__, ec.message().c_str()); + LOG_WRN("%s: switching to degraded mode\n", __func__); + } + + fs::rename(local_path, final_path, ec); + if (ec) { + LOG_WRN("%s: failed to move file to snapshots: %s\n", __func__, ec.message().c_str()); + fs::copy(local_path, final_path, ec); + if (ec) { + LOG_ERR("%s: failed to copy file to snapshots: %s\n", __func__, ec.message().c_str()); + } + } + return file.final_path; +} + +// delete everything after this line, one day + +static std::pair<std::string, std::string> parse_manifest_name(std::string & filename) { + static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)"); + std::smatch match; + if (std::regex_match(filename, match, re)) { + return {match[1].str(), match[2].str()}; + } + return {}; +} + +static std::string make_old_cache_filename(const std::string & owner, + const std::string & repo, + const std::string & filename) { + auto result = owner + "_" + repo + "_" + filename; + string_replace_all(result, "/", "_"); + return result; +} + +static bool migrate_single_file(const fs::path & old_cache, + const std::string & owner, + const std::string & repo, + const nl::json & node, + const hf_files & files) { + + if (!node.contains("rfilename") || + !node.contains("lfs") || + !node["lfs"].contains("sha256")) { + return false; + } + + std::string path = node["rfilename"]; + std::string sha256 = node["lfs"]["sha256"]; + + const hf_file * file_info = nullptr; + for (const auto & f : files) { + if (f.path == path) { + file_info = &f; + break; + } + } + + std::string old_filename = make_old_cache_filename(owner, repo, path); + fs::path old_path = old_cache / old_filename; + fs::path etag_path = old_path.string() + ".etag"; + + if (!fs::exists(old_path)) { + if (fs::exists(etag_path)) { + LOG_WRN("%s: %s is an orphan, deleting...\n", __func__, etag_path.string().c_str()); + fs::remove(etag_path); + } + return false; + } + + bool delete_old_path = false; + + if (!file_info) { + LOG_WRN("%s: %s not found in current repo, deleting...\n", __func__, old_filename.c_str()); + delete_old_path = true; + } else if (!sha256.empty() && !file_info->oid.empty() && sha256 != file_info->oid) { + LOG_WRN("%s: %s is not up to date (sha256 mismatch), deleting...\n", __func__, old_filename.c_str()); + delete_old_path = true; + } + + std::error_code ec; + + if (delete_old_path) { + fs::remove(old_path, ec); + fs::remove(etag_path, ec); + return true; + } + + fs::path new_path(file_info->local_path); + fs::create_directories(new_path.parent_path(), ec); + + if (!fs::exists(new_path, ec)) { + fs::rename(old_path, new_path, ec); + if (ec) { + fs::copy_file(old_path, new_path, ec); + if (ec) { + LOG_WRN("%s: failed to move/copy %s: %s\n", __func__, old_path.string().c_str(), ec.message().c_str()); + return false; + } + } + fs::remove(old_path, ec); + } + fs::remove(etag_path, ec); + + std::string filename = finalize_file(*file_info); + LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), filename.c_str()); + + return true; +} + +void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) { + fs::path old_cache = fs_get_cache_directory(); + if (!fs::exists(old_cache)) { + return; + } + + if (offline) { + LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__); + return; // -hf is not going to work + } + + bool warned = false; + + for (const auto & entry : fs::directory_iterator(old_cache)) { + if (!entry.is_regular_file()) { + continue; + } + auto filename = entry.path().filename().string(); + auto [owner, repo] = parse_manifest_name(filename); + + if (owner.empty() || repo.empty()) { + continue; + } + + if (!warned) { + warned = true; + LOG_WRN("================================================================================\n" "WARNING: Migrating cache to HuggingFace cache directory\n" " Old cache: %s\n" " New cache: %s\n" "This one-time migration moves models previously downloaded with -hf\n" "from the legacy llama.cpp cache to the standard HuggingFace cache.\n" "Models downloaded with --model-url are not affected.\n" "================================================================================\n", old_cache.string().c_str(), get_cache_directory().string().c_str()); + } + + auto repo_id = owner + "/" + repo; + auto files = get_repo_files(repo_id, token); + + if (files.empty()) { + LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str()); + continue; + } + + try { + std::ifstream manifest(entry.path()); + auto json = nl::json::parse(manifest); + + for (const char * key : {"ggufFile", "mmprojFile"}) { + if (json.contains(key)) { + migrate_single_file(old_cache, owner, repo, json[key], files); + } + } + } catch (const std::exception & e) { + LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what()); + continue; + } + fs::remove(entry.path()); + } +} + +} // namespace hf_cache
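A standalone sketch of the symlink-then-fallback strategy `finalize_file()` applies above; all paths here are hypothetical:

```cpp
#include <filesystem>
#include <fstream>
#include <iostream>

namespace fs = std::filesystem;

// Prefer a relative symlink from the snapshot path to the blob; fall back to
// rename/copy ("degraded mode") where symlinks are unavailable, at the cost
// of one blob no longer backing several snapshots.
int main() {
    fs::path blob     = "blobs/0123abcd";                // downloaded payload
    fs::path snapshot = "snapshots/deadbeef/model.gguf"; // user-visible path

    std::error_code ec;
    fs::create_directories(blob.parent_path(), ec);
    std::ofstream(blob) << "payload";                    // stand-in blob content
    fs::create_directories(snapshot.parent_path(), ec);

    fs::path target = fs::relative(blob, snapshot.parent_path(), ec);
    if (!ec) {
        fs::create_symlink(target, snapshot, ec);        // ../../blobs/0123abcd
    }
    if (ec) {
        fs::rename(blob, snapshot, ec);                  // degraded mode
        if (ec) {
            fs::copy(blob, snapshot, ec);
        }
    }
    std::cout << (ec ? "failed" : "ok") << '\n';
}
```
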
online)\n", __func__); + return; // -hf is not going to work + } + + bool warned = false; + + for (const auto & entry : fs::directory_iterator(old_cache)) { + if (!entry.is_regular_file()) { + continue; + } + auto filename = entry.path().filename().string(); + auto [owner, repo] = parse_manifest_name(filename); + + if (owner.empty() || repo.empty()) { + continue; + } + + if (!warned) { + warned = true; + LOG_WRN("================================================================================\n" + "WARNING: Migrating cache to HuggingFace cache directory\n" + " Old cache: %s\n" + " New cache: %s\n" + "This one-time migration moves models previously downloaded with -hf\n" + "from the legacy llama.cpp cache to the standard HuggingFace cache.\n" + "Models downloaded with --model-url are not affected.\n" + "================================================================================\n", + old_cache.string().c_str(), get_cache_directory().string().c_str()); + } + + auto repo_id = owner + "/" + repo; + auto files = get_repo_files(repo_id, token); + + if (files.empty()) { + LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str()); + continue; + } + + try { + std::ifstream manifest(entry.path()); + auto json = nl::json::parse(manifest); + + for (const char * key : {"ggufFile", "mmprojFile"}) { + if (json.contains(key)) { + migrate_single_file(old_cache, owner, repo, json[key], files); + } + } + } catch (const std::exception & e) { + LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what()); + continue; + } + fs::remove(entry.path()); + } +} + +} // namespace hf_cache diff --git a/common/hf-cache.h b/common/hf-cache.h new file mode 100644 index 000000000..ee2e98494 --- /dev/null +++ b/common/hf-cache.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +// Ref: https://huggingface.co/docs/hub/local-cache.md + +namespace hf_cache { + +struct hf_file { + std::string path; + std::string url; + std::string local_path; + std::string final_path; + std::string oid; + std::string repo_id; +}; + +using hf_files = std::vector; + +// Get files from HF API +hf_files get_repo_files( + const std::string & repo_id, + const std::string & token +); + +hf_files get_cached_files(const std::string & repo_id = {}); + +// Create snapshot path (link or move/copy) and return it +std::string finalize_file(const hf_file & file); + +// TODO: Remove later +void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false); + +} // namespace hf_cache diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 9162342ee..89539bd76 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -246,6 +246,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break; case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break; case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break; + case GGML_UNARY_OP_FLOOR: op_num = OP_UNARY_NUM_FLOOR; break; + case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break; + case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break; + case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 4abd26357..b93316ce6 100644 --- 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 9162342ee..89539bd76 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -246,6 +246,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
                 case GGML_UNARY_OP_EXP:      op_num = OP_UNARY_NUM_EXP;      break;
                 case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
                 case GGML_UNARY_OP_EXPM1:    op_num = OP_UNARY_NUM_EXPM1;    break;
+                case GGML_UNARY_OP_FLOOR:    op_num = OP_UNARY_NUM_FLOOR;    break;
+                case GGML_UNARY_OP_CEIL:     op_num = OP_UNARY_NUM_CEIL;     break;
+                case GGML_UNARY_OP_ROUND:    op_num = OP_UNARY_NUM_ROUND;    break;
+                case GGML_UNARY_OP_TRUNC:    op_num = OP_UNARY_NUM_TRUNC;    break;
                 default: GGML_ABORT("fatal error");
             } break;
         default: GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 4abd26357..b93316ce6 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1045,6 +1045,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_SOFTPLUS:
                 case GGML_UNARY_OP_EXPM1:
+                case GGML_UNARY_OP_FLOOR:
+                case GGML_UNARY_OP_CEIL:
+                case GGML_UNARY_OP_ROUND:
+                case GGML_UNARY_OP_TRUNC:
                     return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
                 default:
                     return false;
@@ -1154,6 +1158,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 op->src[0]->ne[0] != 192 &&
                 op->src[0]->ne[0] != 256 &&
                 op->src[0]->ne[0] != 320 &&
+                op->src[0]->ne[0] != 512 &&
                 op->src[0]->ne[0] != 576) {
                 return false;
             }
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index ea471090c..eb2253e02 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -120,6 +120,10 @@
 #define OP_UNARY_NUM_EXP      114
 #define OP_UNARY_NUM_SOFTPLUS 115
 #define OP_UNARY_NUM_EXPM1    116
+#define OP_UNARY_NUM_FLOOR    117
+#define OP_UNARY_NUM_CEIL     118
+#define OP_UNARY_NUM_ROUND    119
+#define OP_UNARY_NUM_TRUNC    120
 
 #define OP_SUM_ROWS_NUM_SUM_ROWS 10
 #define OP_SUM_ROWS_NUM_MEAN     11
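With the op numbers and kernel bodies wired up, the four rounding ops are reachable from the public ggml API. A tiny CPU-side smoke test, as a sketch — the graph-compute helper lives in the CPU backend header in recent trees, so the includes are an assumption:

#include "ggml.h"
#include "ggml-cpu.h"

#include <stdio.h>
#include <string.h>

int main(void) {
    struct ggml_init_params params = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ NULL, /*no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    const float in[4] = { -1.5f, -0.5f, 0.5f, 1.5f };
    memcpy(x->data, in, sizeof(in));

    struct ggml_tensor * y = ggml_round(ctx, x); // likewise: ggml_floor, ggml_ceil, ggml_trunc

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < 4; ++i) {
        // round() rounds half away from zero, so the expected outputs are -2, -1, 1, 2
        printf("round(%.1f) = %.1f\n", in[i], ((float *) y->data)[i]);
    }

    ggml_free(ctx);
    return 0;
}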
[[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif @@ -6316,6 +6335,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6331,6 +6351,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6346,6 +6367,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6361,6 +6383,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; 
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6376,6 +6399,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES @@ -6957,6 +6981,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 84dc6d8f1..7019766b8 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2564,7 +2564,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer) {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, @@ -2725,7 +2725,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, - 
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 84dc6d8f1..7019766b8 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -2564,7 +2564,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TOKEN_EMBD,      {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
     {LLM_TENSOR_POS_EMBD,        {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
     {LLM_TENSOR_TOKEN_TYPES,     {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
+    {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer)
     {LLM_TENSOR_OUTPUT,          {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS,             {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS_OUT,         {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
@@ -2725,7 +2725,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_LAUREL_POST_NORM,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
-    {LLM_TENSOR_CONV1D,               {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
+    {LLM_TENSOR_CONV1D,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
     {LLM_TENSOR_POS_NET_NORM,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM1,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM2,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 066725089..0136665d2 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -350,14 +350,6 @@ llama_context::llama_context(
 
     if (cparams.pipeline_parallel) {
         LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__);
-
-        if (!graph_reuse_disable) {
-            // TODO: figure out a way to make graph reuse work with pipeline parallelism
-            // ref: https://github.com/ggml-org/llama.cpp/pull/20463
-            LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__);
-
-            graph_reuse_disable = true;
-        }
     }
 
     sched_reserve();
@@ -1199,6 +1191,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     if (!graph_reuse_disable && res->can_reuse(gparams)) {
         //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
 
+        // with pipeline parallelism, the previous graph_compute_async may still be running
+        // on the GPU. we must synchronize before set_inputs to avoid overwriting input tensors
+        // that the previous compute is still reading.
+        if (cparams.pipeline_parallel) {
+            ggml_backend_sched_synchronize(sched.get());
+        }
+
         n_reused++;
     } else {
         res->reset();
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 9da4cc9b5..1724514ea 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -928,11 +928,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
             llama_seq_id seq_id;
             io.read_to(&seq_id, sizeof(seq_id));
 
-            // TODO: llama_memory_recurrent should have a notion of max sequences
-            //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-            if (seq_id < 0) {
-                //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-                LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+            if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max);
                 return false;
             }
 
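The llama-context change replaces the old behavior (force-disabling graph reuse under pipeline parallelism) with an explicit synchronize before a reused graph's inputs are rewritten. The hazard it guards against, as a simplified pseudo-driver — set_inputs and the surrounding loop are illustrative, not llama.cpp internals:

// without the synchronize, step N+1 would call set_inputs() while step N's
// asynchronous compute may still be reading the same input tensors
for (int step = 0; step < n_steps; ++step) {
    if (reusing_graph && pipeline_parallel) {
        ggml_backend_sched_synchronize(sched); // wait for in-flight async compute
    }
    set_inputs(gf);                                    // hypothetical: writes the graph's input tensors
    ggml_backend_sched_graph_compute_async(sched, gf); // returns before the backend finishes
}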
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 5dc7794c3..b07c1540e 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3375,8 +3375,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
                     }
 
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -3423,7 +3423,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_MODERN_BERT:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
 
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
 
@@ -3506,8 +3506,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
                     type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings
 
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias
 
                     cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
                     cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED);
@@ -3558,8 +3558,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            case LLM_ARCH_BLOOM:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0);
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -5938,8 +5938,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // Block 0, LN0
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0);
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -6053,8 +6053,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // Block 0, LN0
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0);
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -6225,8 +6225,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0);
 
-                    conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0);
-                    conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0);
+                    conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0);
+                    conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0);
 
                     // posnet
                     {
@@ -6291,8 +6291,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
 
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0);
 
                     // convnext
                     {
"weight"), {hparams.posnet.n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0); // convnext { diff --git a/src/llama.cpp b/src/llama.cpp index 9793bfe80..3f652342b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -388,14 +388,14 @@ static void llama_params_fit_impl( case LAYER_FRACTION_ATTN: { static std::array patterns; if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|gate|down).*"; + patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|up|gate_up|down).*"; } return patterns[il].c_str(); } case LAYER_FRACTION_UP: { static std::array patterns; if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|down).*"; + patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|gate_up|down).*"; } return patterns[il].c_str(); } @@ -409,7 +409,7 @@ static void llama_params_fit_impl( case LAYER_FRACTION_MOE: { static std::array patterns; if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|down|gate)_(ch|)exps"; + patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|down|gate_up|gate)_(ch|)exps"; } return patterns[il].c_str(); } @@ -503,7 +503,7 @@ static void llama_params_fit_impl( int64_t global_surplus_cpu_moe = 0; if (hp_nex > 0) { - const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate)_(ch|)exps"; // matches all MoE tensors + const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate_up|gate)_(ch|)exps"; // matches all MoE tensors ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft}; tensor_buft_overrides[1] = {nullptr, nullptr}; diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 873317914..6ab8c1368 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -28,8 +28,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); auto * inp_attn = build_attn_inp_no_cache(); diff --git a/src/models/bloom.cpp b/src/models/bloom.cpp index b1c19bb58..aa4b939b7 100644 --- a/src/models/bloom.cpp +++ b/src/models/bloom.cpp @@ -16,8 +16,8 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, - LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp index 26020584c..766232109 100644 --- a/src/models/modern-bert.cpp +++ b/src/models/modern-bert.cpp @@ -15,8 +15,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); diff --git a/src/models/rwkv6.cpp b/src/models/rwkv6.cpp index 
diff --git a/src/models/bert.cpp b/src/models/bert.cpp
index 873317914..6ab8c1368 100644
--- a/src/models/bert.cpp
+++ b/src/models/bert.cpp
@@ -28,8 +28,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
     cb(inpL, "inp_embd", -1);
 
     // embed layer norm
-    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
-    cb(inpL, "inp_norm", -1);
+    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0);
+    cb(inpL, "inp_norm", 0);
 
     auto * inp_attn = build_attn_inp_no_cache();
 
diff --git a/src/models/bloom.cpp b/src/models/bloom.cpp
index b1c19bb58..aa4b939b7 100644
--- a/src/models/bloom.cpp
+++ b/src/models/bloom.cpp
@@ -16,8 +16,8 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para
     inpL = build_norm(inpL,
             model.tok_norm,
             model.tok_norm_b,
-            LLM_NORM, -1);
-    cb(inpL, "inp_norm", -1);
+            LLM_NORM, 0);
+    cb(inpL, "inp_norm", 0);
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp
index 26020584c..766232109 100644
--- a/src/models/modern-bert.cpp
+++ b/src/models/modern-bert.cpp
@@ -15,8 +15,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll
     cb(inpL, "inp_embd", -1);
 
     // embed layer norm
-    inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
-    cb(inpL, "inp_norm", -1);
+    inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, 0);
+    cb(inpL, "inp_norm", 0);
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
diff --git a/src/models/rwkv6.cpp b/src/models/rwkv6.cpp
index 15453fbf5..032b219d6 100644
--- a/src/models/rwkv6.cpp
+++ b/src/models/rwkv6.cpp
@@ -8,7 +8,7 @@ llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_para
     ggml_tensor * inpL;
 
     inpL = build_inp_embd(model.tok_embd);
-    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0);
 
     auto * rs_inp = build_rs_inp();
 
diff --git a/src/models/rwkv7.cpp b/src/models/rwkv7.cpp
index 5caf6553d..16ffa6901 100644
--- a/src/models/rwkv7.cpp
+++ b/src/models/rwkv7.cpp
@@ -9,7 +9,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para
     ggml_tensor * v_first = nullptr;
 
     inpL = build_inp_embd(model.tok_embd);
-    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0);
 
     auto * rs_inp = build_rs_inp();
 
diff --git a/src/models/wavtokenizer-dec.cpp b/src/models/wavtokenizer-dec.cpp
index 537a0d412..a7776d9cd 100644
--- a/src/models/wavtokenizer-dec.cpp
+++ b/src/models/wavtokenizer-dec.cpp
@@ -93,7 +93,7 @@ llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model
     cur = build_norm(cur,
             model.tok_norm,
             model.tok_norm_b,
-            LLM_NORM, -1);
+            LLM_NORM, 0);
 
     cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
 
diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz
index c4410bf24..0144abda4 100644
Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ
diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py
index e85f2c338..717007a44 100644
--- a/tools/server/tests/unit/test_router.py
+++ b/tools/server/tests/unit/test_router.py
@@ -103,8 +103,8 @@ def test_router_models_max_evicts_lru():
 
     candidate_models = [
         "ggml-org/tinygemma3-GGUF:Q8_0",
-        "ggml-org/test-model-stories260K",
-        "ggml-org/test-model-stories260K-infill",
+        "ggml-org/test-model-stories260K:F32",
+        "ggml-org/test-model-stories260K-infill:F32",
     ]
 
     # Load only the first 2 models to fill the cache
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte
index f0855b9db..86c182acd 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte
@@ -26,6 +26,7 @@
 	onMount(() => {
 		if (textareaElement) {
+			autoResizeTextarea(textareaElement);
 			textareaElement.focus();
 		}
 	});
 
@@ -50,8 +51,9 @@