Merge commit '72b090da2c' into concedo_experimental

# Conflicts:
#	docs/backend/CANN.md
#	docs/function-calling.md
#	examples/embedding/embedding.cpp
#	examples/retrieval/retrieval.cpp
#	ggml/src/ggml-cann/CMakeLists.txt
#	ggml/src/ggml-cann/Doxyfile
#	ggml/src/ggml-cann/acl_tensor.cpp
#	ggml/src/ggml-cann/acl_tensor.h
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/aclnn_ops.h
#	ggml/src/ggml-cann/common.h
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-sycl/binbcast.cpp
#	ggml/src/ggml-sycl/common.hpp
#	ggml/src/ggml-sycl/concat.cpp
#	ggml/src/ggml-sycl/conv.cpp
#	ggml/src/ggml-sycl/cpy.cpp
#	ggml/src/ggml-sycl/dmmv.cpp
#	ggml/src/ggml-sycl/element_wise.cpp
#	ggml/src/ggml-sycl/getrows.cpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-sycl/gla.cpp
#	ggml/src/ggml-sycl/mmvq.cpp
#	ggml/src/ggml-sycl/norm.cpp
#	ggml/src/ggml-sycl/outprod.cpp
#	ggml/src/ggml-sycl/rope.cpp
#	ggml/src/ggml-sycl/softmax.cpp
#	ggml/src/ggml-sycl/tsembd.cpp
#	ggml/src/ggml-sycl/wkv.cpp
#	scripts/compare-commits.sh
#	tests/test-chat.cpp
#	tests/test-sampling.cpp
This commit is contained in:
Concedo 2025-05-28 00:28:41 +08:00
commit 8c701d7ded
20 changed files with 380 additions and 221 deletions

View file

@ -243,7 +243,56 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
}
// download one single file from remote URL to local path
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token) {
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
std::string etag;
std::string last_modified;
if (file_exists) {
if (offline) {
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
return true; // skip verification/downloading
}
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
}
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
} else {
if (offline) {
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
return false;
}
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
};
common_load_model_from_url_headers headers;
bool head_request_ok = false;
bool should_download = !file_exists; // by default, we should download if the file does not exist
// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
@ -270,49 +319,6 @@ static bool common_download_file_single(const std::string & url, const std::stri
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
std::string etag;
std::string last_modified;
if (file_exists) {
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
}
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
} else {
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
};
common_load_model_from_url_headers headers;
bool head_request_ok = false;
bool should_download = !file_exists; // by default, we should download if the file does not exist
// get ETag to see if the remote file has changed
{
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
@ -355,7 +361,6 @@ static bool common_download_file_single(const std::string & url, const std::stri
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
head_request_ok = false;
}
}
// if head_request_ok is false, we don't have the etag or last-modified headers
// we leave should_download as-is, which is true if the file does not exist
@ -461,12 +466,12 @@ static bool common_download_file_single(const std::string & url, const std::stri
// download multiple files from remote URLs to local paths
// the input is a vector of pairs <url, path>
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token) {
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
for (auto const & item : urls) {
futures_download.push_back(std::async(std::launch::async, [bearer_token](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token);
futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token, offline);
}, item));
}
@ -482,14 +487,15 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
static bool common_download_model(
const common_params_model & model,
const std::string & bearer_token) {
const std::string & bearer_token,
bool offline) {
// Basic validation of the model.url
if (model.url.empty()) {
LOG_ERR("%s: invalid model url\n", __func__);
return false;
}
if (!common_download_file_single(model.url, model.path, bearer_token)) {
if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
return false;
}
@ -548,7 +554,7 @@ static bool common_download_model(
}
// Download in parallel
common_download_file_multiple(urls, bearer_token);
common_download_file_multiple(urls, bearer_token, offline);
}
return true;
@ -609,7 +615,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
*
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
*/
static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
@ -639,20 +645,25 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
long res_code = 0;
std::string res_str;
bool use_cache = false;
if (!offline) {
try {
auto res = common_remote_get_content(url, params);
res_code = res.first;
res_str = std::string(res.second.data(), res.second.size());
} catch (const std::exception & e) {
LOG_WRN("error: failed to get manifest: %s\n", e.what());
LOG_WRN("try reading from cache\n");
// try to read from cache
try {
LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
}
}
if (res_code == 0) {
if (std::filesystem::exists(cached_response_path)) {
LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
res_str = read_file(cached_response_path);
res_code = 200;
use_cache = true;
} catch (const std::exception & e) {
throw std::runtime_error("error: failed to get manifest (check your internet connection)");
} else {
throw std::runtime_error(
offline ? "error: failed to get manifest (offline mode)"
: "error: failed to get manifest (check your internet connection)");
}
}
std::string ggufFile;
@ -699,24 +710,25 @@ bool common_has_curl() {
return false;
}
static bool common_download_file_single(const std::string &, const std::string &, const std::string &) {
static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) {
LOG_ERR("error: built without CURL, cannot download model from internet\n");
return false;
}
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> &, const std::string &) {
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> &, const std::string &, bool) {
LOG_ERR("error: built without CURL, cannot download model from the internet\n");
return false;
}
static bool common_download_model(
const common_params_model &,
const std::string &) {
const std::string &,
bool) {
LOG_ERR("error: built without CURL, cannot download model from the internet\n");
return false;
}
static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &) {
static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) {
LOG_ERR("error: built without CURL, cannot download model from the internet\n");
return {};
}
@ -743,7 +755,8 @@ struct handle_model_result {
static handle_model_result common_params_handle_model(
struct common_params_model & model,
const std::string & bearer_token,
const std::string & model_path_default) {
const std::string & model_path_default,
bool offline) {
handle_model_result result;
// handle pre-fill default model path and url based on hf_repo and hf_file
{
@ -751,7 +764,7 @@ static handle_model_result common_params_handle_model(
// short-hand to avoid specifying --hf-file -> default it to --model
if (model.hf_file.empty()) {
if (model.path.empty()) {
auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token);
auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline);
if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
exit(1); // built without CURL, error message already printed
}
@ -792,7 +805,7 @@ static handle_model_result common_params_handle_model(
// then, download it if needed
if (!model.url.empty()) {
bool ok = common_download_model(model, bearer_token);
bool ok = common_download_model(model, bearer_token, offline);
if (!ok) {
LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
exit(1);
@ -935,7 +948,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// handle model and download
{
auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH);
auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline);
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@ -945,12 +958,12 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// only download mmproj if the current example is using it
for (auto & ex : mmproj_examples) {
if (ctx_arg.ex == ex) {
common_params_handle_model(params.mmproj, params.hf_token, "");
common_params_handle_model(params.mmproj, params.hf_token, "", params.offline);
break;
}
}
common_params_handle_model(params.speculative.model, params.hf_token, "");
common_params_handle_model(params.vocoder.model, params.hf_token, "");
common_params_handle_model(params.speculative.model, params.hf_token, "", params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, "", params.offline);
}
if (params.escape) {
@ -2997,6 +3010,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_verbosity_thold(INT_MAX);
}
));
add_opt(common_arg(
{"--offline"},
"Offline mode: forces use of cache, prevents network access",
[](common_params & params) {
params.offline = true;
}
).set_env("LLAMA_OFFLINE"));
add_opt(common_arg(
{"-lv", "--verbosity", "--log-verbosity"}, "N",
"Set the verbosity threshold. Messages with a higher verbosity will be ignored.",

View file

@ -170,20 +170,23 @@ std::string common_chat_msg_parser::consume_rest() {
}
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;
if (add_prelude_to_content) {
add_content(prelude);
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;
return find_regex_result{prelude, m.groups};
}

View file

@ -30,6 +30,7 @@ class common_chat_msg_parser {
const std::string & healing_marker() const { return healing_marker_; }
const bool & is_partial() const { return is_partial_; }
const common_chat_msg & result() const { return result_; }
const common_chat_syntax & syntax() const { return syntax_; }
void move_to(size_t pos) {
if (pos > input_.size()) {
@ -77,7 +78,7 @@ class common_chat_msg_parser {
std::vector<common_string_range> groups;
};
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
bool try_consume_literal(const std::string & literal);

View file

@ -31,6 +31,11 @@ static std::string string_diff(const std::string & last, const std::string & cur
return current;
}
if (!string_starts_with(current, last)) {
if (string_starts_with(last, current)) {
// This happens if the last generation ended on a partial stop word (not erased),
// and the current ended on a stop word (erased).
return "";
}
throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'");
}
return current.substr(last.size());
@ -101,9 +106,9 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
if (!args_diff.empty() || pref.id != newf.id) {
auto & diff = diffs.emplace_back();
diff.tool_call_index = idx;
diff.tool_call_delta.name = newf.name;
if (pref.id != newf.id) {
diff.tool_call_delta.id = newf.id;
diff.tool_call_delta.name = newf.name;
}
diff.tool_call_delta.arguments = args_diff;
}
@ -387,22 +392,19 @@ template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_di
delta["content"] = diff.content_delta;
}
if (diff.tool_call_index != std::string::npos) {
json tool_call;
tool_call["index"] = diff.tool_call_index;
if (!diff.tool_call_delta.id.empty()) {
tool_call["id"] = diff.tool_call_delta.id;
tool_call["type"] = "function";
}
json function = json::object();
if (!diff.tool_call_delta.name.empty()) {
function["name"] = diff.tool_call_delta.name;
}
if (!diff.tool_call_delta.id.empty()) {
function["id"] = diff.tool_call_delta.id;
}
if (!diff.tool_call_delta.arguments.empty()) {
function["arguments"] = diff.tool_call_delta.arguments;
}
delta["tool_calls"] = json::array({
json {
{"index", diff.tool_call_index},
{"function", function}
}
});
tool_call["function"] = function;
delta["tool_calls"] = json::array({tool_call});
}
return delta;
}
@ -654,7 +656,6 @@ static void parse_json_tool_calls(
}
from = std::string::npos;
builder.add_content(res->prelude);
auto maybe_raw_python = name == "python" && allow_raw_python;
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
@ -684,7 +685,6 @@ static void parse_json_tool_calls(
};
if (block_open) {
if (auto res = builder.try_find_regex(*block_open)) {
builder.add_content(res->prelude);
parse_tool_calls();
} else {
builder.add_content(builder.consume_rest());
@ -697,7 +697,6 @@ static void parse_json_tool_calls(
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) {
static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
if (auto res = builder.try_find_regex(prefix)) {
builder.add_content(res->prelude);
builder.move_back(rstrip_prefix);
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
@ -833,6 +832,10 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
return data;
}
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::vector<std::vector<std::string>> content_paths = {
{"response"},
};
@ -905,6 +908,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
return data;
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
@ -999,7 +1007,6 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
builder.add_content(res->prelude);
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
@ -1014,11 +1021,7 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
// If we didn't extract thoughts, prelude includes them.
builder.add_content(res->prelude);
if (auto res = builder.try_find_regex(end_response_regex)) {
builder.add_content(res->prelude);
} else {
if (!builder.try_find_regex(end_response_regex)) {
builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
}
@ -1126,6 +1129,11 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
return data;
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
@ -1136,8 +1144,6 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
builder.add_content(res->prelude);
auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);
@ -1253,6 +1259,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
@ -1314,6 +1324,10 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
return data;
}
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape(" functools["));
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
}
@ -1455,15 +1469,12 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
return data;
}
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
if (auto res = builder.try_find_regex(python_tag_regex)) {
builder.add_content(res->prelude);
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
static const common_regex function_regex(R"(<function=(\w+)>)");
static const common_regex close_regex(R"(</function>)");
@ -1475,6 +1486,12 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
function_regex,
close_regex,
std::nullopt);
if (auto res = builder.try_find_regex(python_tag_regex)) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
return;
}
}
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
@ -1593,6 +1610,10 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
}
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex open_regex(
"(?:"
@ -1614,8 +1635,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
);
if (auto res = builder.try_find_regex(open_regex)) {
builder.add_content(res->prelude);
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
@ -1851,10 +1870,10 @@ static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}
static void common_chat_parse(common_chat_msg_parser & builder, common_chat_format format) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(format), builder.input().c_str());
static void common_chat_parse(common_chat_msg_parser & builder) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
switch (format) {
switch (builder.syntax().format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
common_chat_parse_content_only(builder);
break;
@ -1889,7 +1908,7 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form
common_chat_parse_command_r7b(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(format));
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
builder.finish();
}
@ -1897,7 +1916,7 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder, syntax.format);
common_chat_parse(builder);
} catch (const common_chat_msg_partial_exception & ex) {
LOG_DBG("Partial parse: %s\n", ex.what());
if (!is_partial) {

View file

@ -144,6 +144,7 @@ struct common_chat_syntax {
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool parse_tool_calls = true;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

View file

@ -287,6 +287,7 @@ struct common_params {
int32_t verbosity = 0;
int32_t control_vector_layer_start = -1; // layer range for control vector
int32_t control_vector_layer_end = -1; // layer range for control vector
bool offline = false;
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line

View file

@ -10,8 +10,8 @@ Proof of concept:
``` sh
export model_name=llama_3.2-1b && export quantization=f32
./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
```
The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.

View file

@ -1604,6 +1604,9 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
for (int i = 0; i < sched->n_backends; i++) {
ggml_backend_synchronize(sched->backends[i]);
}
// reset the current copy to 0 so that the graphs will be similar during generation
// necessary for CUDA graphs
sched->cur_copy = 0;
}
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {

View file

@ -168,7 +168,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
#if !defined(GGML_USE_HIP)
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
static const char * cu_get_error_str(CUresult err) {
const char * err_str;
cuGetErrorString(err, &err_str);

View file

@ -6476,6 +6476,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
case GGML_OP_CONV_2D_DW:
case GGML_OP_IM2COL:
return true;
default:
return false;

View file

@ -615,11 +615,11 @@ extern "C" {
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
"Use llama_kv_self_seq_pos_max() instead");
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
"Use llama_kv_self_seq_pos_max() instead");
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_self_clear(

View file

@ -693,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
// TODO: move the validation to the llama_batch_allocr
if (batch.token) {
for (int32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
throw -1;
}
}
}
@ -852,7 +858,7 @@ int llama_context::encode(llama_batch & inp_batch) {
int llama_context::decode(llama_batch & inp_batch) {
if (!memory) {
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
return encode(inp_batch);
}
@ -887,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
// TODO: move the validation to the llama_batch_allocr
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
return -1;
}
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
return -1;
}
}
}

View file

@ -286,31 +286,11 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
}
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
llama_pos result = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < cells.size(); ++i) {
if (cells.seq_has(i, seq_id)) {
result = std::min(result, cells.pos_get(i));
}
}
if (result == std::numeric_limits<llama_pos>::max()) {
result = -1;
}
return result;
return cells.seq_pos_min(seq_id);
}
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_pos result = -1;
for (uint32_t i = 0; i < cells.size(); ++i) {
if (cells.seq_has(i, seq_id)) {
result = std::max(result, cells.pos_get(i));
}
}
return result;
return cells.seq_pos_max(seq_id);
}
void llama_kv_cache_unified::restore() {
@ -504,7 +484,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
#ifdef FIND_SLOT_DEBUG
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@ -1018,7 +998,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
const uint32_t n_layer = layers.size();
const uint32_t n_kv = cell_max();
const uint32_t n_kv = cells.used_max_p1();
const uint32_t n_used = cells.get_used();
assert(n_used <= n_kv);
@ -1144,16 +1124,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
return true;
}
uint32_t llama_kv_cache_unified::cell_max() const {
for (uint32_t i = cells.size(); i > 0; --i) {
if (!cells.is_empty(i - 1)) {
return i;
}
}
return 0;
}
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
assert(p0 >= 0 && p1 >= 0);

View file

@ -246,10 +246,6 @@ private:
// return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes);
// find how many cells are currently in use
// TODO: optimize
uint32_t cell_max() const;
size_t total_size() const;
size_t size_k_bytes() const;

View file

@ -6,6 +6,7 @@
#include <bitset>
#include <cassert>
#include <vector>
#include <set>
// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
@ -18,8 +19,13 @@ public:
seq[i].reset();
}
used = 0;
has_shift = false;
used.clear();
for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
seq_pos[s].clear();
}
}
void reset_shift() {
@ -50,7 +56,25 @@ public:
}
uint32_t get_used() const {
return used;
return used.size();
}
// the index of the first cell that is used
// return 0 if no cells are used
uint32_t used_min() const {
return used.empty() ? 0 : *used.begin();
}
// the index of the last cell that is used + 1
// return 0 if no cells are used
uint32_t used_max_p1() const {
#if 0
if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
#endif
return used.empty() ? 0 : *used.rbegin() + 1;
}
bool get_has_shift() const {
@ -69,6 +93,9 @@ public:
pos [isrc] = -1;
shift[isrc] = 0;
seq [isrc].reset();
used.erase (isrc);
used.insert(idst);
}
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
@ -95,16 +122,24 @@ public:
for (uint32_t j = 0; j < other.pos.size(); ++j) {
if (pos[i + j] == -1 && other.pos[j] != -1) {
used++;
used.insert(i + j);
}
if (pos[i + j] != -1 && other.pos[j] == -1) {
used--;
used.erase(i + j);
}
if (pos[i + j] != -1) {
seq_pos_rm(i + j);
}
pos[i + j] = other.pos[j];
seq[i + j] = other.seq[j];
if (pos[i + j] != -1) {
seq_pos_add(i + j);
}
assert(shift[i + j] == 0);
}
}
@ -118,11 +153,12 @@ public:
assert(seq_id >= 0);
seq[i].reset(seq_id);
seq_pos[seq_id].erase(pos[i]);
if (seq[i].none()) {
pos[i] = -1;
used--;
used.erase(i);
return true;
}
@ -135,17 +171,22 @@ public:
assert(i < pos.size());
if (seq[i].test(seq_id)) {
seq_pos_rm(i);
seq[i].reset();
seq[i].set(seq_id);
seq_pos[seq_id].insert(pos[i]);
return false;
}
if (seq[i].any()) {
seq_pos_rm(i);
seq[i].reset();
pos[i] = -1;
used--;
used.erase(i);
return true;
}
@ -169,6 +210,33 @@ public:
assert(!seq[i].test(seq_id));
seq[i].set(seq_id);
seq_pos[seq_id].insert(pos[i]);
}
// the minimum position of sequence seq_id currently present in any of the cells
// return -1 if the sequence is not present
llama_pos seq_pos_min(llama_seq_id seq_id) const {
assert(seq_id >= 0);
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
if (seq_pos[seq_id].empty()) {
return -1;
}
return *seq_pos[seq_id].begin();
}
// the maximum position of sequence seq_id currently present in any of the cells
// return -1 if the sequence is not present
llama_pos seq_pos_max(llama_seq_id seq_id) const {
assert(seq_id >= 0);
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
if (seq_pos[seq_id].empty()) {
return -1;
}
return *seq_pos[seq_id].rbegin();
}
// note: call only if the cell is not empty
@ -202,7 +270,8 @@ public:
assert(pos[i] == -1);
pos[i] = p;
used++;
used.insert(i);
}
// pos[i] = pos[i] + d
@ -212,16 +281,22 @@ public:
assert(i < pos.size());
assert(pos[i] != -1);
seq_pos_rm(i);
pos[i] += d;
shift[i] += d;
seq_pos_add(i);
has_shift = true;
if (pos[i] < 0) {
pos[i] = -1;
seq[i].reset();
seq_pos_rm(i);
used--;
seq[i].reset();
pos[i] = -1;
used.erase(i);
return true;
}
@ -238,17 +313,22 @@ public:
const llama_pos p_old = pos[i];
seq_pos_rm(i);
pos[i] /= d;
shift[i] += p_old - pos[i];
seq_pos_add(i);
has_shift = true;
}
private:
uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
bool has_shift = false;
// set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
std::set<uint32_t> used;
std::vector<llama_pos> pos;
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
@ -268,6 +348,32 @@ private:
//
std::vector<llama_pos> shift;
std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
};
using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
std::vector<bits_t> seq;
// the set seq_pos[s] tells us which positions are currently present for sequence s
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
// helper functions for updating `seq_pos`, once cell at a time:
// remove cell i
void seq_pos_rm(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
if (seq[i].test(s)) {
seq_pos[s].erase(pos[i]);
}
}
}
// add cell i
void seq_pos_add(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
if (seq[i].test(s)) {
seq_pos[s].insert(pos[i]);
}
}
}
};

View file

@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
}
// if we have enough values the operation was a success
if (filtered_tokens.size() >= ctx->min_keep) {
if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
cur_p->size = filtered_tokens.size();
min_p_applied = true;
@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
cum_sum += cur_p->data[idx].p;
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
last_idx = i + 1;
break;
}

View file

@ -364,6 +364,7 @@ struct server_task {
params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
}
{
@ -3394,13 +3395,7 @@ struct server_context {
batch.logits + i,
};
int ret = 0;
if (do_encode) {
ret = llama_encode(ctx, batch_view);
} else {
ret = llama_decode(ctx, batch_view);
}
const int ret = llama_decode(ctx, batch_view);
metrics.on_decoded(slots);

View file

@ -121,6 +121,30 @@ def test_completion_stream_with_openai_library():
assert match_regex("(going|bed)+", output_text)
# Test case from https://github.com/ggml-org/llama.cpp/issues/13780
@pytest.mark.slow
def test_completion_stream_with_openai_library_stops():
global server
server.model_hf_repo = "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M"
server.model_hf_file = None
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.completions.create(
model="davinci-002",
prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
stop=["User:\n", "Assistant:\n"],
max_tokens=200,
stream=True,
)
output_text = ''
for data in res:
choice = data.choices[0]
if choice.finish_reason is None:
assert choice.text is not None
output_text += choice.text
assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}'
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
global server

View file

@ -328,6 +328,10 @@ class ServerProcess:
if 'function' not in tc:
raise ValueError(f"Expected function type, got {tc['type']}")
if tc['index'] >= len(tool_calls):
assert 'id' in tc
assert tc.get('type') == 'function'
assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
f"Expected function call with name, got {tc.get('function')}"
tool_calls.append(dict(
id="",
type="function",
@ -340,10 +344,10 @@ class ServerProcess:
if tc.get('id') is not None:
tool_call['id'] = tc['id']
fct = tc['function']
assert 'id' not in fct, f"Function call should not have id: {fct}"
if fct.get('name') is not None:
tool_call['function']['name'] = fct['name']
tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
if fct.get('arguments') is not None:
assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
tool_call['function']['arguments'] += fct['arguments']
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')

View file

@ -735,9 +735,12 @@ static json oaicompat_chat_params_parse(
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.reasoning_format = opt.reasoning_format;
inputs.enable_thinking = opt.enable_thinking;
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
if (body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
llama_params["parse_tool_calls"] = true;
}
// if the assistant message appears at the end of list, we do not add end-of-turn token
// for ex. this can be useful to modify the reasoning process in reasoning models