diff --git a/common/arg.cpp b/common/arg.cpp index 584e0a7c7..04a01ae35 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -243,7 +243,56 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma } // download one single file from remote URL to local path -static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token) { +static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) { + // Check if the file already exists locally + auto file_exists = std::filesystem::exists(path); + + // If the file exists, check its JSON metadata companion file. + std::string metadata_path = path + ".json"; + nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead + std::string etag; + std::string last_modified; + + if (file_exists) { + if (offline) { + LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str()); + return true; // skip verification/downloading + } + // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). + std::ifstream metadata_in(metadata_path); + if (metadata_in.good()) { + try { + metadata_in >> metadata; + LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + if (metadata.contains("etag") && metadata.at("etag").is_string()) { + etag = metadata.at("etag"); + } + if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { + last_modified = metadata.at("lastModified"); + } + } catch (const nlohmann::json::exception & e) { + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + } + } + // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again) + } else { + if (offline) { + LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str()); + return false; + } + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); + } + + // Send a HEAD request to retrieve the etag and last-modified headers + struct common_load_model_from_url_headers { + std::string etag; + std::string last_modified; + }; + + common_load_model_from_url_headers headers; + bool head_request_ok = false; + bool should_download = !file_exists; // by default, we should download if the file does not exist + // Initialize libcurl curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); curl_slist_ptr http_headers; @@ -270,91 +319,47 @@ static bool common_download_file_single(const std::string & url, const std::stri curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); #endif - // Check if the file already exists locally - auto file_exists = std::filesystem::exists(path); + typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); + auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { + common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; - // If the file exists, check its JSON metadata companion file. 
- std::string metadata_path = path + ".json"; - nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead - std::string etag; - std::string last_modified; + static std::regex header_regex("([^:]+): (.*)\r\n"); + static std::regex etag_regex("ETag", std::regex_constants::icase); + static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). - std::ifstream metadata_in(metadata_path); - if (metadata_in.good()) { - try { - metadata_in >> metadata; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); - if (metadata.contains("etag") && metadata.at("etag").is_string()) { - etag = metadata.at("etag"); - } - if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { - last_modified = metadata.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + std::string header(buffer, n_items); + std::smatch match; + if (std::regex_match(header, match, header_regex)) { + const std::string & key = match[1]; + const std::string & value = match[2]; + if (std::regex_match(key, match, etag_regex)) { + headers->etag = value; + } else if (std::regex_match(key, match, last_modified_regex)) { + headers->last_modified = value; } } - // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again) - } else { - LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); - } - - // Send a HEAD request to retrieve the etag and last-modified headers - struct common_load_model_from_url_headers { - std::string etag; - std::string last_modified; + return n_items; }; - common_load_model_from_url_headers headers; - bool head_request_ok = false; - bool should_download = !file_exists; // by default, we should download if the file does not exist + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); + curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - // get ETag to see if the remote file has changed - { - typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); - auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { - common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; + // we only allow retrying once for HEAD requests + // this is for the use case of using running offline (no internet), retrying can be annoying + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD"); + if (!was_perform_successful) { + head_request_ok = false; + } - static std::regex header_regex("([^:]+): (.*)\r\n"); - static std::regex etag_regex("ETag", std::regex_constants::icase); - static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - - std::string header(buffer, n_items); - std::smatch match; - if (std::regex_match(header, match, header_regex)) { - const std::string & key = match[1]; - const std::string & value = match[2]; - if (std::regex_match(key, match, etag_regex)) { - headers->etag = value; - } else if 
(std::regex_match(key, match, last_modified_regex)) { - headers->last_modified = value; - } - } - return n_items; - }; - - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); - curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - - // we only allow retrying once for HEAD requests - // this is for the use case of using running offline (no internet), retrying can be annoying - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD"); - if (!was_perform_successful) { - head_request_ok = false; - } - - long http_code = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code == 200) { - head_request_ok = true; - } else { - LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); - head_request_ok = false; - } + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code == 200) { + head_request_ok = true; + } else { + LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + head_request_ok = false; } // if head_request_ok is false, we don't have the etag or last-modified headers @@ -461,12 +466,12 @@ static bool common_download_file_single(const std::string & url, const std::stri // download multiple files from remote URLs to local paths // the input is a vector of pairs -static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token) { +static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { // Prepare download in parallel std::vector> futures_download; for (auto const & item : urls) { - futures_download.push_back(std::async(std::launch::async, [bearer_token](const std::pair & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token); + futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token, offline); }, item)); } @@ -482,14 +487,15 @@ static bool common_download_file_multiple(const std::vector> common_remote_get_content(const std::string & * * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. */ -static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) { +static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) { auto parts = string_split(hf_repo_with_tag, ':'); std::string tag = parts.size() > 1 ? 
parts.back() : "latest"; std::string hf_repo = parts[0]; @@ -639,20 +645,25 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ long res_code = 0; std::string res_str; bool use_cache = false; - try { - auto res = common_remote_get_content(url, params); - res_code = res.first; - res_str = std::string(res.second.data(), res.second.size()); - } catch (const std::exception & e) { - LOG_WRN("error: failed to get manifest: %s\n", e.what()); - LOG_WRN("try reading from cache\n"); - // try to read from cache + if (!offline) { try { + auto res = common_remote_get_content(url, params); + res_code = res.first; + res_str = std::string(res.second.data(), res.second.size()); + } catch (const std::exception & e) { + LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what()); + } + } + if (res_code == 0) { + if (std::filesystem::exists(cached_response_path)) { + LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str()); res_str = read_file(cached_response_path); res_code = 200; use_cache = true; - } catch (const std::exception & e) { - throw std::runtime_error("error: failed to get manifest (check your internet connection)"); + } else { + throw std::runtime_error( + offline ? "error: failed to get manifest (offline mode)" + : "error: failed to get manifest (check your internet connection)"); } } std::string ggufFile; @@ -699,24 +710,25 @@ bool common_has_curl() { return false; } -static bool common_download_file_single(const std::string &, const std::string &, const std::string &) { +static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) { LOG_ERR("error: built without CURL, cannot download model from internet\n"); return false; } -static bool common_download_file_multiple(const std::vector> &, const std::string &) { +static bool common_download_file_multiple(const std::vector> &, const std::string &, bool) { LOG_ERR("error: built without CURL, cannot download model from the internet\n"); return false; } static bool common_download_model( const common_params_model &, - const std::string &) { + const std::string &, + bool) { LOG_ERR("error: built without CURL, cannot download model from the internet\n"); return false; } -static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &) { +static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) { LOG_ERR("error: built without CURL, cannot download model from the internet\n"); return {}; } @@ -743,7 +755,8 @@ struct handle_model_result { static handle_model_result common_params_handle_model( struct common_params_model & model, const std::string & bearer_token, - const std::string & model_path_default) { + const std::string & model_path_default, + bool offline) { handle_model_result result; // handle pre-fill default model path and url based on hf_repo and hf_file { @@ -751,7 +764,7 @@ static handle_model_result common_params_handle_model( // short-hand to avoid specifying --hf-file -> default it to --model if (model.hf_file.empty()) { if (model.path.empty()) { - auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token); + auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline); if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { exit(1); // built without CURL, error message already printed } @@ -792,7 +805,7 @@ static handle_model_result common_params_handle_model( // then, download it if needed if 
(!model.url.empty()) { - bool ok = common_download_model(model, bearer_token); + bool ok = common_download_model(model, bearer_token, offline); if (!ok) { LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); exit(1); @@ -935,7 +948,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // handle model and download { - auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH); + auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline); if (params.no_mmproj) { params.mmproj = {}; } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { @@ -945,12 +958,12 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // only download mmproj if the current example is using it for (auto & ex : mmproj_examples) { if (ctx_arg.ex == ex) { - common_params_handle_model(params.mmproj, params.hf_token, ""); + common_params_handle_model(params.mmproj, params.hf_token, "", params.offline); break; } } - common_params_handle_model(params.speculative.model, params.hf_token, ""); - common_params_handle_model(params.vocoder.model, params.hf_token, ""); + common_params_handle_model(params.speculative.model, params.hf_token, "", params.offline); + common_params_handle_model(params.vocoder.model, params.hf_token, "", params.offline); } if (params.escape) { @@ -2997,6 +3010,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex common_log_set_verbosity_thold(INT_MAX); } )); + add_opt(common_arg( + {"--offline"}, + "Offline mode: forces use of cache, prevents network access", + [](common_params & params) { + params.offline = true; + } + ).set_env("LLAMA_OFFLINE")); add_opt(common_arg( {"-lv", "--verbosity", "--log-verbosity"}, "N", "Set the verbosity threshold. Messages with a higher verbosity will be ignored.", diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 54475683b..c314b8b51 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -170,20 +170,23 @@ std::string common_chat_msg_parser::consume_rest() { } // Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback. -std::optional common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) { +std::optional common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) { auto m = regex.search(input_, from == std::string::npos ? 
pos_ : from); if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { return std::nullopt; } + auto prelude = input_.substr(pos_, m.groups[0].begin - pos_); + pos_ = m.groups[0].end; + + if (add_prelude_to_content) { + add_content(prelude); + } if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { if (is_partial()) { throw common_chat_msg_partial_exception(regex.str()); } return std::nullopt; } - auto prelude = input_.substr(pos_, m.groups[0].begin - pos_); - pos_ = m.groups[0].end; - return find_regex_result{prelude, m.groups}; } diff --git a/common/chat-parser.h b/common/chat-parser.h index b21b32b8a..5d53f2df1 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -30,6 +30,7 @@ class common_chat_msg_parser { const std::string & healing_marker() const { return healing_marker_; } const bool & is_partial() const { return is_partial_; } const common_chat_msg & result() const { return result_; } + const common_chat_syntax & syntax() const { return syntax_; } void move_to(size_t pos) { if (pos > input_.size()) { @@ -77,7 +78,7 @@ class common_chat_msg_parser { std::vector groups; }; - std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos); + std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true); bool try_consume_literal(const std::string & literal); diff --git a/common/chat.cpp b/common/chat.cpp index 9779dc118..5b27c7593 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -31,6 +31,11 @@ static std::string string_diff(const std::string & last, const std::string & cur return current; } if (!string_starts_with(current, last)) { + if (string_starts_with(last, current)) { + // This happens if the last generation ended on a partial stop word (not erased), + // and the current ended on a stop word (erased). 
+ return ""; + } throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'"); } return current.substr(last.size()); @@ -101,9 +106,9 @@ std::vector common_chat_msg_diff::compute_diffs(const comm if (!args_diff.empty() || pref.id != newf.id) { auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; - diff.tool_call_delta.name = newf.name; if (pref.id != newf.id) { diff.tool_call_delta.id = newf.id; + diff.tool_call_delta.name = newf.name; } diff.tool_call_delta.arguments = args_diff; } @@ -387,22 +392,19 @@ template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_di delta["content"] = diff.content_delta; } if (diff.tool_call_index != std::string::npos) { + json tool_call; + tool_call["index"] = diff.tool_call_index; + if (!diff.tool_call_delta.id.empty()) { + tool_call["id"] = diff.tool_call_delta.id; + tool_call["type"] = "function"; + } json function = json::object(); if (!diff.tool_call_delta.name.empty()) { function["name"] = diff.tool_call_delta.name; } - if (!diff.tool_call_delta.id.empty()) { - function["id"] = diff.tool_call_delta.id; - } - if (!diff.tool_call_delta.arguments.empty()) { - function["arguments"] = diff.tool_call_delta.arguments; - } - delta["tool_calls"] = json::array({ - json { - {"index", diff.tool_call_index}, - {"function", function} - } - }); + function["arguments"] = diff.tool_call_delta.arguments; + tool_call["function"] = function; + delta["tool_calls"] = json::array({tool_call}); } return delta; } @@ -654,7 +656,6 @@ static void parse_json_tool_calls( } from = std::string::npos; - builder.add_content(res->prelude); auto maybe_raw_python = name == "python" && allow_raw_python; if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { @@ -684,7 +685,6 @@ static void parse_json_tool_calls( }; if (block_open) { if (auto res = builder.try_find_regex(*block_open)) { - builder.add_content(res->prelude); parse_tool_calls(); } else { builder.add_content(builder.consume_rest()); @@ -697,7 +697,6 @@ static void parse_json_tool_calls( static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { static const std::vector> args_paths = {{"arguments"}}; if (auto res = builder.try_find_regex(prefix)) { - builder.add_content(res->prelude); builder.move_back(rstrip_prefix); auto tool_calls = builder.consume_json_with_dumped_args(args_paths); if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { @@ -833,6 +832,10 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp return data; } static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } static const std::vector> content_paths = { {"response"}, }; @@ -905,6 +908,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat return data; } static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); parse_prefixed_json_tool_call_array(builder, prefix); } @@ -999,7 +1007,6 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { if (auto res = builder.try_find_regex(start_action_regex)) { // If we didn't 
extract thoughts, prelude includes them. - builder.add_content(res->prelude); auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); for (const auto & tool_call : tool_calls.value) { std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; @@ -1014,11 +1021,7 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { } builder.consume_regex(end_action_regex); } else if (auto res = builder.try_find_regex(start_response_regex)) { - // If we didn't extract thoughts, prelude includes them. - builder.add_content(res->prelude); - if (auto res = builder.try_find_regex(end_response_regex)) { - builder.add_content(res->prelude); - } else { + if (!builder.try_find_regex(end_response_regex)) { builder.add_content(builder.consume_rest()); throw common_chat_msg_partial_exception(end_response_regex.str()); } @@ -1126,6 +1129,11 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te return data; } static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex function_regex( "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); static const common_regex close_regex("\\}\\s*"); @@ -1136,8 +1144,6 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w if (with_builtin_tools) { static const common_regex builtin_call_regex("<\\|python_tag\\|>"); if (auto res = builder.try_find_regex(builtin_call_regex)) { - builder.add_content(res->prelude); - auto fun_res = builder.consume_regex(function_name_regex); auto function_name = builder.str(fun_res.groups[1]); @@ -1253,6 +1259,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ } static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); static const common_regex tool_calls_end("<|tool▁calls▁end|>"); @@ -1314,6 +1324,10 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c return data; } static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } static const common_regex prefix(regex_escape(" functools[")); parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); } @@ -1455,15 +1469,12 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con return data; } static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. 
- static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); - - if (auto res = builder.try_find_regex(python_tag_regex)) { - builder.add_content(res->prelude); - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - builder.add_tool_call("python", "", arguments); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); return; } + // This version of Functionary still supports the llama 3.1 tool call format for the python tool. + static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); static const common_regex function_regex(R"()"); static const common_regex close_regex(R"()"); @@ -1475,6 +1486,12 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser function_regex, close_regex, std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; + } } static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -1593,6 +1610,10 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat } static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } static const common_regex open_regex( "(?:" @@ -1614,8 +1635,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { ); if (auto res = builder.try_find_regex(open_regex)) { - builder.add_content(res->prelude); - const auto & block_start = res->groups[1]; std::string block_end = block_start.empty() ? 
"" : "```"; @@ -1851,10 +1870,10 @@ static void common_chat_parse_content_only(common_chat_msg_parser & builder) { builder.add_content(builder.consume_rest()); } -static void common_chat_parse(common_chat_msg_parser & builder, common_chat_format format) { - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(format), builder.input().c_str()); +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); - switch (format) { + switch (builder.syntax().format) { case COMMON_CHAT_FORMAT_CONTENT_ONLY: common_chat_parse_content_only(builder); break; @@ -1889,7 +1908,7 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form common_chat_parse_command_r7b(builder); break; default: - throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(format)); + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } builder.finish(); } @@ -1897,7 +1916,7 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { common_chat_msg_parser builder(input, is_partial, syntax); try { - common_chat_parse(builder, syntax.format); + common_chat_parse(builder); } catch (const common_chat_msg_partial_exception & ex) { LOG_DBG("Partial parse: %s\n", ex.what()); if (!is_partial) { diff --git a/common/chat.h b/common/chat.h index 3e2cbbaae..f6b1d0ffc 100644 --- a/common/chat.h +++ b/common/chat.h @@ -144,6 +144,7 @@ struct common_chat_syntax { // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) bool reasoning_in_content = false; bool thinking_forced_open = false; + bool parse_tool_calls = true; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid diff --git a/common/common.h b/common/common.h index 8ba45a8ee..aad27f500 100644 --- a/common/common.h +++ b/common/common.h @@ -287,6 +287,7 @@ struct common_params { int32_t verbosity = 0; int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector + bool offline = false; int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line diff --git a/examples/training/README.md b/examples/training/README.md index ecdf398f8..df4252792 100644 --- a/examples/training/README.md +++ b/examples/training/README.md @@ -10,8 +10,8 @@ Proof of concept: ``` sh export model_name=llama_3.2-1b && export quantization=f32 -./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512 -./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf +./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512 +./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf ``` The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs. 
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 73a7975a8..a3db16b60 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1604,6 +1604,9 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { for (int i = 0; i < sched->n_backends; i++) { ggml_backend_synchronize(sched->backends[i]); } + // reset the current copy to 0 so that the graphs will be similar during generation + // necessary for CUDA graphs + sched->cur_copy = 0; } void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ca66137fd..d8e312330 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -168,7 +168,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) -#if !defined(GGML_USE_HIP) +#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) static const char * cu_get_error_str(CUresult err) { const char * err_str; cuGetErrorString(err, &err_str); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ca3551465..06c26c9b1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6476,6 +6476,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_ROPE: case GGML_OP_RMS_NORM: case GGML_OP_CONV_2D_DW: + case GGML_OP_IM2COL: return true; default: return false; diff --git a/include/llama.h b/include/llama.h index 04e060b7c..114092158 100644 --- a/include/llama.h +++ b/include/llama.h @@ -615,11 +615,11 @@ extern "C" { // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), - "Use llama_kv_self_seq_pos_max() instead"); + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); // Returns the number of used KV cells (i.e. 
have at least one sequence assigned to them) DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), - "Use llama_kv_self_seq_pos_max() instead"); + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_self_clear( diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4c526414f..9bb23f3fc 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -693,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) { GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + // TODO: move the validation to the llama_batch_allocr if (batch.token) { for (int32_t i = 0; i < n_tokens; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); return -1; } + + if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); + throw -1; + } } } @@ -852,7 +858,7 @@ int llama_context::encode(llama_batch & inp_batch) { int llama_context::decode(llama_batch & inp_batch) { if (!memory) { - LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__); + LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); return encode(inp_batch); } @@ -887,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) { GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + // TODO: move the validation to the llama_batch_allocr if (batch.token) { for (int64_t i = 0; i < n_tokens_all; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]); - throw std::runtime_error("invalid token"); + return -1; + } + + if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); + return -1; } } } diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 8a2a08bc4..56c1939a2 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -286,31 +286,11 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po } llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { - llama_pos result = std::numeric_limits::max(); - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (cells.seq_has(i, seq_id)) { - result = std::min(result, cells.pos_get(i)); - } - } - - if (result == std::numeric_limits::max()) { - result = -1; - } - - return result; + return cells.seq_pos_min(seq_id); } llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { - llama_pos result = -1; - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (cells.seq_has(i, seq_id)) { - result = std::max(result, cells.pos_get(i)); - } - } - - return result; + return cells.seq_pos_max(seq_id); } void llama_kv_cache_unified::restore() { @@ -504,7 +484,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { // a heuristic, to avoid attending the full cache if it is not yet 
utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad))); + n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); #ifdef FIND_SLOT_DEBUG LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); @@ -1018,7 +998,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { const uint32_t n_layer = layers.size(); - const uint32_t n_kv = cell_max(); + const uint32_t n_kv = cells.used_max_p1(); const uint32_t n_used = cells.get_used(); assert(n_used <= n_kv); @@ -1144,16 +1124,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { return true; } -uint32_t llama_kv_cache_unified::cell_max() const { - for (uint32_t i = cells.size(); i > 0; --i) { - if (!cells.is_empty(i - 1)) { - return i; - } - } - - return 0; -} - bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { assert(p0 >= 0 && p1 >= 0); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 86a96820e..ce6261e45 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -246,10 +246,6 @@ private: // return true if cells have been moved bool defrag_prepare(int32_t n_max_nodes); - // find how many cells are currently in use - // TODO: optimize - uint32_t cell_max() const; - size_t total_size() const; size_t size_k_bytes() const; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index 138545533..dbbd03fcb 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -6,6 +6,7 @@ #include #include #include +#include // meta information about KV cells that can be part of multiple sequences at the same time // TODO: add unit tests @@ -18,8 +19,13 @@ public: seq[i].reset(); } - used = 0; has_shift = false; + + used.clear(); + + for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + seq_pos[s].clear(); + } } void reset_shift() { @@ -50,7 +56,25 @@ public: } uint32_t get_used() const { - return used; + return used.size(); + } + + // the index of the first cell that is used + // return 0 if no cells are used + uint32_t used_min() const { + return used.empty() ? 0 : *used.begin(); + } + + // the index of the last cell that is used + 1 + // return 0 if no cells are used + uint32_t used_max_p1() const { +#if 0 + if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin()); + if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin()); + if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin()); +#endif + + return used.empty() ? 
0 : *used.rbegin() + 1; } bool get_has_shift() const { @@ -69,6 +93,9 @@ public: pos [isrc] = -1; shift[isrc] = 0; seq [isrc].reset(); + + used.erase (isrc); + used.insert(idst); } // copy the state of cells [i, i + n) (used for save/restore the state of the cells) @@ -95,16 +122,24 @@ public: for (uint32_t j = 0; j < other.pos.size(); ++j) { if (pos[i + j] == -1 && other.pos[j] != -1) { - used++; + used.insert(i + j); } if (pos[i + j] != -1 && other.pos[j] == -1) { - used--; + used.erase(i + j); + } + + if (pos[i + j] != -1) { + seq_pos_rm(i + j); } pos[i + j] = other.pos[j]; seq[i + j] = other.seq[j]; + if (pos[i + j] != -1) { + seq_pos_add(i + j); + } + assert(shift[i + j] == 0); } } @@ -118,11 +153,12 @@ public: assert(seq_id >= 0); seq[i].reset(seq_id); + seq_pos[seq_id].erase(pos[i]); if (seq[i].none()) { pos[i] = -1; - used--; + used.erase(i); return true; } @@ -135,17 +171,22 @@ public: assert(i < pos.size()); if (seq[i].test(seq_id)) { + seq_pos_rm(i); seq[i].reset(); + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); return false; } if (seq[i].any()) { + seq_pos_rm(i); seq[i].reset(); + pos[i] = -1; - used--; + used.erase(i); return true; } @@ -169,6 +210,33 @@ public: assert(!seq[i].test(seq_id)); seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + } + + // the minimum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_min(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].begin(); + } + + // the maximum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_max(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].rbegin(); } // note: call only if the cell is not empty @@ -202,7 +270,8 @@ public: assert(pos[i] == -1); pos[i] = p; - used++; + + used.insert(i); } // pos[i] = pos[i] + d @@ -212,16 +281,22 @@ public: assert(i < pos.size()); assert(pos[i] != -1); + seq_pos_rm(i); + pos[i] += d; shift[i] += d; + seq_pos_add(i); + has_shift = true; if (pos[i] < 0) { - pos[i] = -1; - seq[i].reset(); + seq_pos_rm(i); - used--; + seq[i].reset(); + pos[i] = -1; + + used.erase(i); return true; } @@ -238,17 +313,22 @@ public: const llama_pos p_old = pos[i]; + seq_pos_rm(i); + pos[i] /= d; shift[i] += p_old - pos[i]; + seq_pos_add(i); + has_shift = true; } private: - uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id) - bool has_shift = false; + // set of indices of used cells (i.e. 
pos[i] != -1, allowed to not have any seq_id) + std::set used; + std::vector pos; // this array accumulates any applied shifts to the pos array since the last reset_shift() call @@ -268,6 +348,32 @@ private: // std::vector shift; - std::vector> seq; -}; + using bits_t = std::bitset; + // the bitset seq[i] tells us which sequences are currently occupying the i-th cell + std::vector seq; + + // the set seq_pos[s] tells us which positions are currently present for sequence s + // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache + std::set seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES]; + + // helper functions for updating `seq_pos`, once cell at a time: + + // remove cell i + void seq_pos_rm(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + seq_pos[s].erase(pos[i]); + } + } + } + + // add cell i + void seq_pos_add(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + seq_pos[s].insert(pos[i]); + } + } + } +}; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 804b11e0a..bfbf5fa23 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d } // if we have enough values the operation was a success - if (filtered_tokens.size() >= ctx->min_keep) { + if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) { memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); cur_p->size = filtered_tokens.size(); min_p_applied = true; @@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token cum_sum += cur_p->data[idx].p; // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > ctx->p && i >= ctx->min_keep - 1) { + if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) { last_idx = i + 1; break; } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 07b613122..fe6c685ec 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -364,6 +364,7 @@ struct server_task { params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format; params.oaicompat_chat_syntax.reasoning_in_content = params.stream; params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); } { @@ -3394,13 +3395,7 @@ struct server_context { batch.logits + i, }; - int ret = 0; - - if (do_encode) { - ret = llama_encode(ctx, batch_view); - } else { - ret = llama_decode(ctx, batch_view); - } + const int ret = llama_decode(ctx, batch_view); metrics.on_decoded(slots); diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index 4099c4e25..f6909e9ae 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -121,6 +121,30 @@ def test_completion_stream_with_openai_library(): assert match_regex("(going|bed)+", output_text) +# Test case from https://github.com/ggml-org/llama.cpp/issues/13780 +@pytest.mark.slow +def test_completion_stream_with_openai_library_stops(): + global server + server.model_hf_repo = "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M" + server.model_hf_file = None + server.start() + client = OpenAI(api_key="dummy", 
base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.completions.create( + model="davinci-002", + prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n", + stop=["User:\n", "Assistant:\n"], + max_tokens=200, + stream=True, + ) + output_text = '' + for data in res: + choice = data.choices[0] + if choice.finish_reason is None: + assert choice.text is not None + output_text += choice.text + assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}' + + @pytest.mark.parametrize("n_slots", [1, 2]) def test_consistent_result_same_seed(n_slots: int): global server diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 11672f515..f7e1b3b3b 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -328,6 +328,10 @@ class ServerProcess: if 'function' not in tc: raise ValueError(f"Expected function type, got {tc['type']}") if tc['index'] >= len(tool_calls): + assert 'id' in tc + assert tc.get('type') == 'function' + assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \ + f"Expected function call with name, got {tc.get('function')}" tool_calls.append(dict( id="", type="function", @@ -340,10 +344,10 @@ class ServerProcess: if tc.get('id') is not None: tool_call['id'] = tc['id'] fct = tc['function'] + assert 'id' not in fct, f"Function call should not have id: {fct}" if fct.get('name') is not None: - tool_call['function']['name'] = fct['name'] + tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name'] if fct.get('arguments') is not None: - assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!' tool_call['function']['arguments'] += fct['arguments'] print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts') diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index fc9f7071e..8456a02e6 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -735,8 +735,11 @@ static json oaicompat_chat_params_parse( inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); inputs.reasoning_format = opt.reasoning_format; inputs.enable_thinking = opt.enable_thinking; - if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { + if (body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + llama_params["parse_tool_calls"] = true; } // if the assistant message appears at the end of list, we do not add end-of-turn token
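
The reshaped streaming delta in `common_chat_msg_diff_to_json_oaicompat` now follows the OpenAI wire format: `index`, and on the first chunk of a tool call also `id` and `type`, sit at the top level of each `tool_calls` entry, while `function` carries `name` (only alongside a new `id`) and an `arguments` fragment. A client therefore accumulates fragments per `index`, as the updated helper in `tools/server/tests/utils.py` does. Below is a minimal client-side sketch of that accumulation; it assumes nlohmann::json and a `deltas` vector holding the `delta` objects taken from the streamed chunks, and the function name `accumulate_tool_calls` is illustrative rather than part of the patch.

```cpp
// Sketch only: accumulate streamed OpenAI-compatible tool-call deltas into
// complete tool calls, mirroring the logic added to ServerProcess in
// tools/server/tests/utils.py. Assumes indices arrive in order starting at 0,
// as the server emits them.
#include <nlohmann/json.hpp>
#include <string>
#include <vector>

using json = nlohmann::json;

static std::vector<json> accumulate_tool_calls(const std::vector<json> & deltas) {
    std::vector<json> tool_calls;
    for (const auto & delta : deltas) {
        if (!delta.contains("tool_calls")) {
            continue;
        }
        for (const auto & tc : delta.at("tool_calls")) {
            const size_t idx = tc.at("index").get<size_t>();
            if (idx >= tool_calls.size()) {
                // the first fragment of a call carries "id", "type" and the name
                tool_calls.push_back(json{
                    {"id",       tc.value("id", "")},
                    {"type",     "function"},
                    {"function", json{{"name", ""}, {"arguments", ""}}},
                });
            }
            auto & dst = tool_calls[idx]["function"];
            const auto & fn = tc.at("function");
            if (fn.contains("name")) {
                dst["name"] = dst["name"].get<std::string>() + fn.at("name").get<std::string>();
            }
            if (fn.contains("arguments")) {
                dst["arguments"] = dst["arguments"].get<std::string>() + fn.at("arguments").get<std::string>();
            }
        }
    }
    return tool_calls;
}
```

This mirrors the assertions added to the test utilities: the first fragment of a call must carry `id`, `type == "function"` and a non-empty `function.name`, and `id` must never appear inside the `function` object itself.
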