Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/cuda.Dockerfile
#	.devops/musa.Dockerfile
#	.github/workflows/build.yml
#	README.md
#	docs/docker.md
#	examples/imatrix/imatrix.cpp
#	examples/llama-bench/llama-bench.cpp
#	examples/main/README.md
#	examples/perplexity/perplexity.cpp
#	examples/server/README.md
#	ggml/src/ggml-cpu/ggml-cpu.c
#	ggml/src/ggml-cuda/CMakeLists.txt
#	models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja
#	models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja
#	scripts/get_chat_template.py
#	scripts/sync-ggml.last
#	tests/test-chat.cpp
#	tests/test-gguf.cpp
#	tests/test-sampling.cpp
Commit 754fef5204 by Concedo, 2025-02-15 00:49:46 +08:00
27 changed files with 1845 additions and 548 deletions

@@ -366,6 +366,112 @@ static void common_params_print_usage(common_params_context & ctx_arg) {
print_options(specific_options);
}
static void common_params_print_completion(common_params_context & ctx_arg) {
std::vector<common_arg *> common_options;
std::vector<common_arg *> sparam_options;
std::vector<common_arg *> specific_options;
for (auto & opt : ctx_arg.options) {
if (opt.is_sparam) {
sparam_options.push_back(&opt);
} else if (opt.in_example(ctx_arg.ex)) {
specific_options.push_back(&opt);
} else {
common_options.push_back(&opt);
}
}
printf("_llama_completions() {\n");
printf(" local cur prev opts\n");
printf(" COMPREPLY=()\n");
printf(" cur=\"${COMP_WORDS[COMP_CWORD]}\"\n");
printf(" prev=\"${COMP_WORDS[COMP_CWORD-1]}\"\n\n");
printf(" opts=\"");
auto print_options = [](const std::vector<common_arg *> & options) {
for (const common_arg * opt : options) {
for (const char * arg : opt->args) {
printf("%s ", arg);
}
}
};
print_options(common_options);
print_options(sparam_options);
print_options(specific_options);
printf("\"\n\n");
printf(" case \"$prev\" in\n");
printf(" --model)\n");
printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
printf(" return 0\n");
printf(" ;;\n");
printf(" --grammar-file)\n");
printf(" COMPREPLY=( $(compgen -f -X '!*.gbnf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
printf(" return 0\n");
printf(" ;;\n");
printf(" --chat-template-file)\n");
printf(" COMPREPLY=( $(compgen -f -X '!*.jinja' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
printf(" return 0\n");
printf(" ;;\n");
printf(" *)\n");
printf(" COMPREPLY=( $(compgen -W \"${opts}\" -- \"$cur\") )\n");
printf(" return 0\n");
printf(" ;;\n");
printf(" esac\n");
printf("}\n\n");
std::set<std::string> executables = {
"llama-batched",
"llama-batched-bench",
"llama-bench",
"llama-cli",
"llama-convert-llama2c-to-ggml",
"llama-cvector-generator",
"llama-embedding",
"llama-eval-callback",
"llama-export-lora",
"llama-gbnf-validator",
"llama-gen-docs",
"llama-gguf",
"llama-gguf-hash",
"llama-gguf-split",
"llama-gritlm",
"llama-imatrix",
"llama-infill",
"llama-llava-cli",
"llama-llava-clip-quantize-cli",
"llama-lookahead",
"llama-lookup",
"llama-lookup-create",
"llama-lookup-merge",
"llama-lookup-stats",
"llama-minicpmv-cli",
"llama-parallel",
"llama-passkey",
"llama-perplexity",
"llama-q8dot",
"llama-quantize",
"llama-quantize-stats",
"llama-qwen2vl-cli",
"llama-retrieval",
"llama-run",
"llama-save-load-state",
"llama-server",
"llama-simple",
"llama-simple-chat",
"llama-speculative",
"llama-speculative-simple",
"llama-tokenize",
"llama-tts",
"llama-vdot"
};
for (const auto& exe : executables) {
printf("complete -F _llama_completions %s\n", exe.c_str());
}
}
static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & value) {
std::vector<ggml_backend_dev_t> devices;
auto dev_names = string_split<std::string>(value, ',');
@@ -427,6 +533,10 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
}
exit(0);
}
if (ctx_arg.params.completion) {
common_params_print_completion(ctx_arg);
exit(0);
}
} catch (const std::invalid_argument & ex) {
fprintf(stderr, "%s\n", ex.what());
ctx_arg.params = params_org;
@@ -495,6 +605,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
exit(0);
}
));
add_opt(common_arg(
{"--completion-bash"},
"print source-able bash completion script for llama.cpp",
[](common_params & params) {
params.completion = true;
}
));
add_opt(common_arg(
{"--verbose-prompt"},
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
@@ -947,6 +1064,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.min_p = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--top-nsigma"}, "N",
string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma),
[](common_params & params, const std::string & value) {
params.sampling.top_n_sigma = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam());
add_opt(common_arg(
{"--xtc-probability"}, "N",
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
@@ -1976,6 +2100,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg(
{"--reasoning-format"}, "FORMAT",
"reasoning format (default: deepseek; allowed values: deepseek, none)\n"
"controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).\n"
"only supported for non-streamed responses",
[](common_params & params, const std::string & value) {
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
else { std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(

@@ -12,11 +12,13 @@ std::string common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)";
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -105,7 +107,6 @@ static common_chat_msg parse_json_tool_calls(
std::sregex_iterator rend;
std::sregex_iterator rit(it, end, function_regex);
if (rit == rend) {
- fprintf(stderr, "No more tool calls found\n");
result.content += std::string(it, end);
break;
}
@@ -115,14 +116,21 @@
json arguments;
if (!parse_json(it, end, arguments)) {
- throw std::runtime_error("Failed to parse json tool call arguments");
throw std::runtime_error("Failed to parse json tool call arguments: " + input);
}
if (!std::regex_search(it, end, match, close_regex)) {
- throw std::runtime_error("Malformed input, missing closing pattern");
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
}
it = match.suffix().first;
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
}
if (!result.tool_calls.empty()) {
if (!string_strip(result.content).empty()) {
LOG_WRN("Content found with tool calls: %s\n", result.content.c_str());
}
result.content = "";
}
return result;
}
@@ -134,11 +142,11 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
result.role = "assistant";
const auto process_tool_calls = [&](const json & tool_calls) {
for (const auto & tool_call : tool_calls) {
- const auto & arguments = tool_call["arguments"];
const auto & arguments = tool_call.at("arguments");
result.tool_calls.push_back({
- tool_call["name"],
tool_call.at("name"),
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
- tool_call.contains("id") ? tool_call["id"] : "",
tool_call.contains("id") ? tool_call.at("id") : "",
});
}
};
@@ -155,7 +163,7 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
for (const auto & tool : tools) {
- if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
continue;
}
@@ -190,27 +198,27 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
auto tool_call_schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
- const auto & function = tool["function"];
const auto & function = tool.at("function");
auto tool_schema = json {
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
- {"const", function["name"]},
{"const", function.at("name")},
}},
- {"arguments", function["parameters"]},
{"arguments", function.at("parameters")},
}},
{"required", json::array({"name", "arguments"})},
};
if (function.contains("description")) {
- tool_schema["description"] = function["description"];
tool_schema["description"] = function.at("description");
}
if (inputs.parallel_tool_calls) {
- tool_schema["properties"]["id"] = {
tool_schema.at("properties")["id"] = {
{"type", "string"},
{"minLength", 4},
};
- tool_schema["required"].push_back("id");
tool_schema.at("required").push_back("id");
}
tool_call_schemas.emplace_back(tool_schema);
});
@@ -275,21 +283,21 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
common_chat_msg result;
result.role = "assistant";
if (data.contains("tool_calls")) {
- for (const auto & tool_call : data["tool_calls"]) {
for (const auto & tool_call : data.at("tool_calls")) {
result.tool_calls.push_back({
- tool_call["name"],
tool_call.at("name"),
- tool_call["arguments"].dump(),
tool_call.at("arguments").dump(),
- tool_call.contains("id") ? tool_call["id"] : "",
tool_call.contains("id") ? tool_call.at("id") : "",
});
}
} else if (data.contains("tool_call")) {
result.tool_calls.push_back({
- data["tool_call"]["name"],
data.at("tool_call").at("name"),
- data["tool_call"]["arguments"].dump(),
data.at("tool_call").at("arguments").dump(),
/* id= */ "",
});
} else if (data.contains("response")) {
- const auto & response = data["response"];
const auto & response = data.at("response");
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
}
return result;
@@ -301,7 +309,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
- const auto & function = tool["function"];
const auto & function = tool.at("function");
schemas.push_back({
{"type", "object"},
{"properties", {
@@ -309,9 +317,9 @@
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
{"name", {
{"type", "string"},
- {"const", function["name"]},
{"const", function.at("name")},
}},
- {"arguments", function["parameters"]},
{"arguments", function.at("parameters")},
{"id", {
{"type", "string"},
// Nemo's template expects a 9-character alphanumeric ID.
@@ -346,7 +354,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
- const auto & function = tool["function"];
const auto & function = tool.at("function");
schemas.push_back({
{"type", "object"},
{"properties", {
@@ -357,9 +365,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
}},
{"tool_name", {
{"type", "string"},
- {"const", function["name"]},
{"const", function.at("name")},
}},
- {"parameters", function["parameters"]},
{"parameters", function.at("parameters")},
}},
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
});
@@ -382,39 +390,65 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
"<|END_THINKING|>",
"<|END_ACTION|>",
};
- data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
- data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) {
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
if (has_reasoning_content && has_tool_calls) {
auto adjusted_message = msg;
adjusted_message["tool_plan"] = msg.at("reasoning_content");
adjusted_message.erase("reasoning_content");
adjusted_messages.push_back(adjusted_message);
} else {
adjusted_messages.push_back(msg);
}
}
data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B;
return data;
}
- static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
- static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
- static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
std::smatch match;
common_chat_msg result;
result.role = "assistant";
- if (std::regex_match(input, match, response_regex)) {
- result.content = match[1].str();
- } else if (std::regex_match(input, match, thought_action_regex)) {
- result.tool_plan = match[1].str();
- auto actions_str = match[2].str();
std::string rest = input;
if (std::regex_match(rest, match, thought_regex)) {
if (extract_reasoning) {
result.reasoning_content = match[2].str();
} else if (!match[2].str().empty()) {
// Let the unparsed thinking tags through in content only if their insides aren't empty.
result.content = match[1].str();
}
rest = match[3].str();
}
if (std::regex_match(rest, match, action_regex)) {
auto actions_str = match[1].str();
auto actions = json::parse(actions_str);
for (const auto & action : actions) {
result.tool_calls.push_back({
- /* .name = */ action["tool_name"],
/* .name = */ action.at("tool_name"),
- /* .arguments = */ action["parameters"].dump(),
/* .arguments = */ action.at("parameters").dump(),
- /* .id = */ action["tool_call_id"],
/* .id = */ action.at("tool_call_id"),
});
}
} else if (std::regex_match(rest, match, response_regex)) {
auto response = match[1].str();
result.content += response;
} else {
- LOG_ERR("Failed to parse command_r output");
- result.content = input;
result.content += rest;
}
return result;
}
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
- if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
}
const auto & parameters_properties = parameters.at("properties");
@@ -468,9 +502,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
};
foreach_function(inputs.tools, [&](const json & tool) {
- const auto & function = tool["function"];
const auto & function = tool.at("function");
- std::string name = function["name"];
std::string name = function.at("name");
- auto parameters = function["parameters"];
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
@@ -546,34 +580,90 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
- data.grammar_lazy = inputs.tool_choice != "required";
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
- const auto & function = tool["function"];
const auto & function = tool.at("function");
- std::string name = function["name"];
std::string name = function.at("name");
- auto parameters = function["parameters"];
auto parameters = function.at("parameters");
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
- "\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n"
"```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
});
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
// so we accept common variants (then it's all constrained)
builder.add_rule("root",
"( \"<tool▁calls▁begin>\" | \"<tool_calls_begin>\" | \"<tool calls begin>\" | \"<tool\\\\_calls\\\\_begin>\" ) "
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
"\"<tool▁calls▁end>\""
" space");
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool_calls_begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool calls begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool\\_calls\\_begin>", /* .at_start = */ false});
data.preserved_tokens = {
"<think>",
"</think>",
"<tool▁sep>", "<tool▁sep>",
"<tool▁calls▁end",
"<tool▁call▁end>", "<tool▁call▁end>",
}; };
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options); }, grammar_options);
}
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
// Hacks to fix the official (broken) prompt.
// It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
// until the official template is fixed.
if (tmpl.source().find("{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}") != std::string::npos) {
// Don't leave the chat dangling after tool results
if (string_ends_with(prompt, "<tool▁outputs▁end>")) {
prompt += "<end▁of▁sentence>";
if (inputs.add_generation_prompt) {
prompt += "<Assistant>";
}
}
// Fix up tool call delta example added by Minja
prompt = std::regex_replace(
prompt,
std::regex("(<tool▁call▁end>)[\\s\\r\\n]*(<tool▁outputs▁begin>|<User>)"),
"$1<tool▁calls▁end><end▁of▁sentence>$2");
}
data.prompt = prompt;
- data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
return data;
}
- static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
- static std::regex trigger_regex("<tool▁calls▁begin>");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
- static std::regex close_regex("```<tool▁call▁end>");
static std::regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
- return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>)([\\s\\S\\r\\n]*?)<tool▁calls▁end>");
common_chat_msg msg;
msg.role = "assistant";
std::smatch match;
if (std::regex_match(input, match, reasoning_content_regex)) {
std::string rest;
if (extract_reasoning) {
msg.reasoning_content = string_strip(match[2].str());
} else {
msg.content = match[1].str();
}
rest = match[3].str();
if (std::regex_search(rest, match, tool_calls_regex)) {
auto tool_calls = match[1].str();
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
msg.tool_calls = std::move(msg2.tool_calls);
} else {
msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
}
} else {
msg.content = input;
}
return msg;
}
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
@@ -583,20 +673,20 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
});
- if (!inputs.tools.is_null() && !inputs.tools.empty()) {
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
- const auto & function = tool["function"];
const auto & function = tool.at("function");
schemas.push_back({
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
- {"const", function["name"]},
{"const", function.at("name")},
}},
- {"arguments", function["parameters"]},
{"arguments", function.at("parameters")},
}},
{"required", json::array({"name", "arguments", "id"})},
});
@@ -628,15 +718,15 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
- if (!inputs.tools.is_null() && !inputs.tools.empty()) {
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
- const auto & function = tool["function"];
const auto & function = tool.at("function");
- std::string name = function["name"];
std::string name = function.at("name");
- auto parameters = function["parameters"];
auto parameters = function.at("parameters");
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -716,9 +806,9 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
- const auto & function = tool["function"];
const auto & function = tool.at("function");
- const auto & parameters = function["parameters"];
const auto & parameters = function.at("parameters");
- std::string name = function["name"];
std::string name = function.at("name");
if (name == "python" || name == "ipython") {
if (!parameters.contains("type")) {
throw std::runtime_error("Missing type in python tool");
@@ -789,9 +879,9 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
- const auto & function = tool["function"];
const auto & function = tool.at("function");
- std::string name = function["name"];
std::string name = function.at("name");
- auto parameters = function["parameters"];
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_schema(name + "-call", {
{"type", "object"},
@@ -839,9 +929,9 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
if (!parse_json(it, end, call)) {
throw std::runtime_error("Failed to parse json tool call");
}
- const auto & arguments = call["arguments"];
const auto & arguments = call.at("arguments");
result.tool_calls.push_back({
- call["name"],
call.at("name"),
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
/* id= */ "",
@@ -884,47 +974,72 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
}
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
- auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
- LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false");
- if (has_tools && !inputs.grammar.empty()) {
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
if (inputs.tools.is_array()) {
if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
}
}
- const auto & src = tmpl.source();
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
if (src.find("<tool▁calls▁begin>") != std::string::npos && inputs.json_schema.is_null()) {
return common_chat_params_init_deepseek_r1(tmpl, inputs);
}
// Command R7B: : use handler in all cases except json schema (thinking / tools).
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
return common_chat_params_init_command_r7b(tmpl, inputs);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
return common_chat_params_init_generic(tmpl, inputs);
}
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
if (src.find(">>>all") != std::string::npos) { if (src.find(">>>all") != std::string::npos) {
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
return common_chat_params_init_functionary_v3_2(tmpl, inputs); return common_chat_params_init_functionary_v3_2(tmpl, inputs);
} }
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
if (src.find(" functools[") != std::string::npos) { if (src.find(" functools[") != std::string::npos) {
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
return common_chat_params_init_firefunction_v2(tmpl, inputs); return common_chat_params_init_firefunction_v2(tmpl, inputs);
} }
if (!has_tools) { // Plain handler (no tools)
if (inputs.tools.is_null() || inputs.tool_choice == "none") {
return common_chat_params_init_without_tools(tmpl, inputs); return common_chat_params_init_without_tools(tmpl, inputs);
} }
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos) { if (src.find("<tool_call>") != std::string::npos) {
return common_chat_params_init_hermes_2_pro(tmpl, inputs); return common_chat_params_init_hermes_2_pro(tmpl, inputs);
} }
// Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) { && src.find("<function=") != std::string::npos) {
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs); return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
} }
// Llama 3.1, 3.2, 3.3 (w/ tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) { if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools); return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
} }
if (src.find("<tool▁calls▁begin>") != std::string::npos) {
return common_chat_params_init_deepseek_r1(tmpl, inputs); // Mistral Nemo (w/ tools)
}
if (src.find("[TOOL_CALLS]") != std::string::npos) { if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_params_init_mistral_nemo(tmpl, inputs); return common_chat_params_init_mistral_nemo(tmpl, inputs);
} }
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
return common_chat_params_init_command_r7b(tmpl, inputs); // Generic fallback
}
return common_chat_params_init_generic(tmpl, inputs); return common_chat_params_init_generic(tmpl, inputs);
} }
@@ -949,7 +1064,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
- return common_chat_parse_deepseek_r1(input);
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
return common_chat_parse_functionary_v3_2(input);
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
@@ -959,7 +1076,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
return common_chat_parse_firefunction_v2(input);
case COMMON_CHAT_FORMAT_COMMAND_R7B:
- return common_chat_parse_command_r7b(input);
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
default:
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
}

@@ -19,6 +19,7 @@ struct common_chat_inputs {
bool stream;
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
};
enum common_chat_format {
@@ -28,11 +29,13 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};

@@ -136,6 +136,7 @@ struct common_params_sampling {
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f;// -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false;
@@ -198,6 +199,11 @@
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
};
struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
@@ -288,6 +294,7 @@
bool kl_divergence = false; // compute KL divergence
bool usage = false; // print usage
bool completion = false; // print source-able completion script
bool use_color = false; // use color to distinguish generations and inputs
bool special = false; // enable special token output
bool interactive = false; // interactive mode
@@ -342,6 +349,7 @@
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
std::vector<std::string> api_keys;
@@ -420,13 +428,13 @@ bool set_process_priority(enum ggml_sched_priority prio);
//
#ifdef __GNUC__
- #ifdef __MINGW32__
- #define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
- #else
- #define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
- #endif
- #else
- #define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
- #endif
# if defined(__MINGW32__) && !defined(__clang__)
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
# else
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
# endif
#else
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
#endif
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
@@ -619,7 +627,7 @@ struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_tool_call> tool_calls;
- std::string tool_plan = "";
std::string reasoning_content = "";
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
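
A small self-contained illustration of what the (gnu_)printf format attribute guarded above buys, using a hypothetical ATTR_FORMAT stand-in for LLAMA_COMMON_ATTRIBUTE_FORMAT: with it, GCC and Clang can diagnose mismatched format strings at the call site, and clang on MinGW only understands the plain printf archetype, hence the !defined(__clang__) check.

#include <cstdarg>
#include <cstdio>

// Hypothetical stand-in for LLAMA_COMMON_ATTRIBUTE_FORMAT with the same MinGW/clang logic.
#ifdef __GNUC__
#  if defined(__MINGW32__) && !defined(__clang__)
#    define ATTR_FORMAT(fmt_idx, arg_idx) __attribute__((format(gnu_printf, fmt_idx, arg_idx)))
#  else
#    define ATTR_FORMAT(fmt_idx, arg_idx) __attribute__((format(printf, fmt_idx, arg_idx)))
#  endif
#else
#  define ATTR_FORMAT(fmt_idx, arg_idx)
#endif

ATTR_FORMAT(1, 2)
static void logf(const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
    vfprintf(stderr, fmt, args);
    va_end(args);
}

int main() {
    logf("n_ctx = %d\n", 4096);        // fine
    // logf("n_ctx = %d\n", "4096");   // the attribute lets the compiler warn about this mismatch
    return 0;
}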

@@ -1,5 +1,6 @@
#include "log.h"
#include <chrono>
#include <condition_variable>
#include <cstdarg>
#include <cstdio>

@@ -15,7 +15,7 @@
#ifndef __GNUC__
# define LOG_ATTRIBUTE_FORMAT(...)
- #elif defined(__MINGW32__)
#elif defined(__MINGW32__) && !defined(__clang__)
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))

@@ -134,11 +134,11 @@ std::string common_params_sampling::print() const {
snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
- "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
- top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
mirostat, mirostat_eta, mirostat_tau);
return std::string(result);
@@ -151,12 +151,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
lparams.no_perf = params.no_perf;
- std::vector<const char *> trigger_words;
- trigger_words.reserve(params.grammar_trigger_words.size());
- for (const auto & str : params.grammar_trigger_words) {
- trigger_words.push_back(str.word.c_str());
- }
struct llama_sampler * grmr;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
@@ -165,6 +159,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<const char *> trigger_words;
trigger_words.reserve(params.grammar_trigger_words.size());
for (const auto & str : params.grammar_trigger_words) {
trigger_words.push_back(str.word.c_str());
}
grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
trigger_words.data(), trigger_words.size(),
@@ -188,6 +188,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
params.logit_bias.data()));
if (params.mirostat == 0) {
if (params.top_n_sigma >= 0) {
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
} else {
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
@@ -229,6 +234,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ASSERT(false && "unknown sampler type");
}
}
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
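
The new branch above wires top-k, temperature, and llama_sampler_init_top_n_sigma into the chain. A rough standalone sketch of the top-n-sigma filtering rule, assuming the commonly cited definition (keep tokens whose logit lies within n standard deviations of the maximum logit); the actual llama.cpp sampler implementation may differ in details.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Keep only logits within top_n_sigma standard deviations of the max; mask the rest.
static void top_n_sigma_filter(std::vector<float> & logits, float top_n_sigma) {
    if (top_n_sigma < 0.0f || logits.empty()) {
        return; // -1.0 = disabled
    }
    float max_logit = logits[0];
    float mean = 0.0f;
    for (float l : logits) {
        max_logit = std::max(max_logit, l);
        mean += l;
    }
    mean /= (float) logits.size();
    float var = 0.0f;
    for (float l : logits) {
        var += (l - mean) * (l - mean);
    }
    const float sigma = std::sqrt(var / (float) logits.size());
    const float threshold = max_logit - top_n_sigma * sigma;
    for (float & l : logits) {
        if (l < threshold) {
            l = -INFINITY; // excluded from sampling after softmax
        }
    }
}

int main() {
    std::vector<float> logits = { 8.0f, 7.5f, 2.0f, -1.0f };
    top_n_sigma_filter(logits, 1.0f);
    for (float l : logits) {
        printf("%.2f\n", l); // the two low-scoring tokens are masked out
    }
    return 0;
}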

Binary file not shown.

@@ -173,6 +173,7 @@ struct slot_params {
{"grammar_trigger_words", grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
{"samplers", samplers}, {"samplers", samplers},
{"speculative.n_max", speculative.n_max}, {"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min}, {"speculative.n_min", speculative.n_min},
@ -724,9 +725,19 @@ struct server_task_result_cmpl_final : server_task_result {
msg.content = content; msg.content = content;
} }
json tool_calls; json message {
{"role", "assistant"},
};
if (!msg.reasoning_content.empty()) {
message["reasoning_content"] = msg.reasoning_content;
}
if (msg.content.empty() && !msg.tool_calls.empty()) {
message["content"] = json();
} else {
message["content"] = msg.content;
}
if (!msg.tool_calls.empty()) {
- tool_calls = json::array();
auto tool_calls = json::array();
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
@@ -737,15 +748,7 @@
{"id", tc.id},
});
}
- }
- json message {
- {"content", msg.content},
- {"tool_calls", tool_calls},
- {"role", "assistant"},
- };
- if (!msg.tool_plan.empty()) {
- message["tool_plan"] = msg.tool_plan;
- }
message["tool_calls"] = tool_calls;
}
json choice { json choice {
@@ -2073,8 +2076,8 @@ struct server_context {
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
- SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict);
slot.params.n_predict = slot.n_predict;
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
}
if (slot.params.ignore_eos && has_eos_token) {
@@ -4060,7 +4063,7 @@ int main(int argc, char ** argv) {
}
auto body = json::parse(req.body);
- json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
@@ -4073,7 +4076,7 @@
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
- json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
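
A small sketch of the assistant message shape the reworked block above produces, using nlohmann::json as the server does; the tool call name, arguments, and id are hypothetical sample values. The key behaviours: content becomes JSON null when the response consists only of tool calls, and reasoning_content is attached only when present.

#include <iostream>
#include <nlohmann/json.hpp>
#include <string>

using json = nlohmann::json;

int main() {
    // Pretend the parsed common_chat_msg had reasoning, no plain content, and one tool call.
    std::string reasoning_content = "I should call get_weather for Paris.";
    std::string content; // empty: the model only emitted a tool call
    bool has_tool_calls = true;

    json message {
        {"role", "assistant"},
    };
    if (!reasoning_content.empty()) {
        message["reasoning_content"] = reasoning_content;
    }
    // Mirror of the server logic: null content when there are only tool calls.
    if (content.empty() && has_tool_calls) {
        message["content"] = json();
    } else {
        message["content"] = content;
    }
    if (has_tool_calls) {
        auto tool_calls = json::array();
        tool_calls.push_back({
            {"type", "function"},
            {"function", {{"name", "get_weather"}, {"arguments", "{\"location\":\"Paris\"}"}}},
            {"id", "call_001"},
        });
        message["tool_calls"] = tool_calls;
    }
    std::cout << message.dump(2) << std::endl;
    return 0;
}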

@@ -92,6 +92,7 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"] assert expected_function_name == tool_call["function"]["name"]
actual_arguments = tool_call["function"]["arguments"] actual_arguments = tool_call["function"]["arguments"]
@ -155,11 +156,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), # (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), # (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
@@ -175,7 +176,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
- (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
# (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
# TODO: fix these
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
@ -214,6 +215,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
tool_calls = choice["message"].get("tool_calls") tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0] tool_call = tool_calls[0]
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"] assert expected_function_name == tool_call["function"]["name"]
actual_arguments = tool_call["function"]["arguments"] actual_arguments = tool_call["function"]["arguments"]
@ -273,7 +275,6 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("hf_repo,template_override", [ @pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
@ -298,13 +299,16 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
]) ])
def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None): def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server global server
n_predict = 512 n_predict = 512
server.n_slots = 1 server.n_slots = 1
@ -323,6 +327,7 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None)
res = server.make_request("POST", "/chat/completions", data={ res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predict, "max_tokens": n_predict,
"messages": [ "messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"}, {"role": "user", "content": "What is the weather in Istanbul?"},
], ],
"tools": [WEATHER_TOOL], "tools": [WEATHER_TOOL],
@ -332,6 +337,7 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None)
tool_calls = choice["message"].get("tool_calls") tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0] tool_call = tool_calls[0]
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"] assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
actual_arguments = json.loads(tool_call["function"]["arguments"]) actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
@ -340,22 +346,166 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None)
assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
@pytest.mark.slow
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
# n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things, and provide very concise answers. Do not explain your reasoning to the user. Provide any numerical values back to the user with at most two decimals."},
{"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_6789",
"type": "function",
"function": {
"name": "calculate",
"arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
}
}
]
},
{
"role": "tool",
"name": "calculate",
"content": 0.55644242476,
"tool_call_id": "call_6789"
}
],
"tools": [
{
"type":"function",
"function":{
"name":"calculate",
"description":"A calculator function that computes values of arithmetic expressions in the Python syntax",
"parameters":{
"type":"object",
"properties":{
"expression":{
"type":"string",
"description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)"
}
},
"required":["expression"]
}
}
}
]
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
assert content is not None, f'Expected content in {choice["message"]}'
if result_override is not None:
assert re.match(result_override, content), f'Expected {result_override}, got {content}'
else:
assert re.match('^[\\s\\S]*?The (y[ -])?coordinate [\\s\\S]*?is (approximately )?0\\.56\\b|^0\\.56$', content), \
f'Expected something like "The y coordinate is 0.56.", got {content}'
@pytest.mark.slow
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
(128, 'deepseek', "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
server.n_slots = 1
server.reasoning_format = reasoning_format
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "user", "content": "What's the sum of 102 and 7?"},
]
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
if expect_content is None:
assert content is None, f'Expected no content in {choice["message"]}'
else:
assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
reasoning_content = choice["message"].get("reasoning_content")
if expect_reasoning_content is None:
assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}'
else:
assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}'
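To make the parametrized expectations above concrete, here is an illustrative sketch (hand-written values, not real model output) of the assistant message shapes the two `--reasoning-format` modes are asserted against:

```python
# With --reasoning-format deepseek the server is expected to strip the
# <think>...</think> span out of content and surface it separately.
message_deepseek = {
    "role": "assistant",
    "content": "To find the sum of 102 and 7, add them together: the result is 109.",
    "reasoning_content": "I need to calculate the sum of 102 and 7. 102 + 7 = 109.",
}

# With --reasoning-format none the raw <think> block stays inline in content
# and no reasoning_content field is produced.
message_none = {
    "role": "assistant",
    "content": "<think>\nI need to calculate the sum of 102 and 7.\n</think>\nTo find the sum: 109.",
}
```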
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
(None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"),
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
(None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), (None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), (None, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
@ -371,15 +521,13 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None)
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
]) ])
def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): def test_hello_world(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server global server
server.n_slots = 1 server.n_slots = 1
server.jinja = True server.jinja = True
server.n_ctx = 8192 server.n_ctx = 8192
server.n_predict = 128 server.n_predict = 512 # High because of DeepSeek R1
server.model_hf_repo = hf_repo server.model_hf_repo = hf_repo
server.model_hf_file = None server.model_hf_file = None
if isinstance(template_override, tuple): if isinstance(template_override, tuple):
@ -406,6 +554,7 @@ def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo:
tool_calls = choice["message"].get("tool_calls") tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0] tool_call = tool_calls[0]
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
actual_arguments = tool_call["function"]["arguments"] actual_arguments = tool_call["function"]["arguments"]
if expected_arguments_override is not None: if expected_arguments_override is not None:

View file

@ -78,6 +78,7 @@ class ServerProcess:
draft_max: int | None = None draft_max: int | None = None
no_webui: bool | None = None no_webui: bool | None = None
jinja: bool | None = None jinja: bool | None = None
reasoning_format: Literal['deepseek', 'none'] | None = None
chat_template: str | None = None chat_template: str | None = None
chat_template_file: str | None = None chat_template_file: str | None = None
@ -172,6 +173,8 @@ class ServerProcess:
server_args.append("--no-webui") server_args.append("--no-webui")
if self.jinja: if self.jinja:
server_args.append("--jinja") server_args.append("--jinja")
if self.reasoning_format is not None:
server_args.extend(("--reasoning-format", self.reasoning_format))
if self.chat_template: if self.chat_template:
server_args.extend(["--chat-template", self.chat_template]) server_args.extend(["--chat-template", self.chat_template])
if self.chat_template_file: if self.chat_template_file:
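A short sketch of how a test could drive the new flag through this harness, mirroring the pattern used in `test_thoughts` above (the fixture, constants, and model repo come from the surrounding suite and are placeholders here):

```python
# Fragment: assumes the suite's global `server` fixture and TIMEOUT_SERVER_START.
server.jinja = True
server.reasoning_format = 'deepseek'   # becomes "--reasoning-format deepseek"; use 'none' to keep <think> inline
server.model_hf_repo = "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M"  # placeholder model
server.start(timeout_seconds=TIMEOUT_SERVER_START)
```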

View file

@ -578,6 +578,7 @@ static json oaicompat_completion_params_parse(const json & body) {
static json oaicompat_completion_params_parse( static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */ const json & body, /* openai api json semantics */
bool use_jinja, bool use_jinja,
common_reasoning_format reasoning_format,
const common_chat_templates & chat_templates) const common_chat_templates & chat_templates)
{ {
json llama_params; json llama_params;
@ -633,6 +634,7 @@ static json oaicompat_completion_params_parse(
throw std::runtime_error("Cannot use custom grammar constraints with tools."); throw std::runtime_error("Cannot use custom grammar constraints with tools.");
} }
common_chat_inputs inputs; common_chat_inputs inputs;
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.messages = body.at("messages"); inputs.messages = body.at("messages");
inputs.tools = tools; inputs.tools = tools;
inputs.tool_choice = tool_choice; inputs.tool_choice = tool_choice;

View file

@ -254,12 +254,12 @@ export default function ChatMessage({
🔄 Regenerate 🔄 Regenerate
</button> </button>
)} )}
</>
)}
<CopyButton <CopyButton
className="badge btn-mini show-on-hover mr-2" className="badge btn-mini show-on-hover mr-2"
content={msg.content} content={msg.content}
/> />
</>
)}
</div> </div>
)} )}
</div> </div>

View file

@ -198,7 +198,7 @@
#ifndef __GNUC__ #ifndef __GNUC__
# define GGML_ATTRIBUTE_FORMAT(...) # define GGML_ATTRIBUTE_FORMAT(...)
#elif defined(__MINGW32__) #elif defined(__MINGW32__) && !defined(__clang__)
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else #else
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))

File diff suppressed because it is too large

View file

@ -7,7 +7,6 @@
#include "ggml-cpu-impl.h" #include "ggml-cpu-impl.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml-quants.h"
#include "ggml-cpu-quants.h" #include "ggml-cpu-quants.h"
#include "ggml-threading.h" #include "ggml-threading.h"
// #include "amx/amx.h" // #include "amx/amx.h"
@ -1295,7 +1294,7 @@ struct ggml_threadpool {
atomic_int n_graph; // incremented when there is work to be done (i.e each graph) atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
atomic_int GGML_CACHE_ALIGN n_barrier; atomic_int GGML_CACHE_ALIGN n_barrier;
atomic_int GGML_CACHE_ALIGN n_barrier_passed; atomic_int GGML_CACHE_ALIGN n_barrier_passed;
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
// these are atomic as an annotation for thread-sanitizer // these are atomic as an annotation for thread-sanitizer
atomic_bool stop; // Used for stopping the threadpool altogether atomic_bool stop; // Used for stopping the threadpool altogether
@ -7528,6 +7527,7 @@ UseGgmlGemm1:;
if (src1->type != vec_dot_type) { if (src1->type != vec_dot_type) {
char * wdata = params->wdata; char * wdata = params->wdata;
const size_t nbw0 = ggml_type_size(vec_dot_type);
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
const size_t nbw2 = nbw1*ne11; const size_t nbw2 = nbw1*ne11;
const size_t nbw3 = nbw2*ne12; const size_t nbw3 = nbw2*ne12;
@ -7535,6 +7535,7 @@ UseGgmlGemm1:;
assert(params->wsize >= ne13*nbw3); assert(params->wsize >= ne13*nbw3);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
#if 0
for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) { for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
@ -7544,6 +7545,20 @@ UseGgmlGemm1:;
} }
} }
} }
#else
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
size_t bs = ggml_blck_size(vec_dot_type);
int64_t ne10_block_start = (ith * ne10/bs) / nth;
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
(ne10_block_end - ne10_block_start) * bs);
}
}
}
#endif
} }
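The `#else` branch above changes how `src1` is converted to `vec_dot_type`: instead of handing whole rows to individual threads, every row is split across all threads in units of quant blocks. A plain-Python sketch of the index arithmetic, with made-up sizes:

```python
# Sketch: per-thread slice of a single src1 row, in vec_dot_type block units.
ne10 = 4096   # row length in floats
bs   = 32     # ggml_blck_size(vec_dot_type)
nth  = 6      # number of threads

for ith in range(nth):
    ne10_block_start = (ith * ne10 // bs) // nth
    ne10_block_end   = ((ith + 1) * ne10 // bs) // nth
    # thread `ith` converts floats [start, end) of every (i11, i12, i13) row
    start, end = ne10_block_start * bs, ne10_block_end * bs
    print(f"thread {ith}: floats [{start}, {end})")
```

This keeps all threads busy even when `ne11` is smaller than the thread count, which the old one-row-per-thread split could not do.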
if (ith == 0) { if (ith == 0) {
@ -7631,7 +7646,6 @@ UseGgmlGemm2:;
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) { if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
num_rows_per_vec_dot = 1; num_rows_per_vec_dot = 1;
} }
ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
if (nth >= nchunk0 * nchunk1) { if (nth >= nchunk0 * nchunk1) {
@ -7644,144 +7658,44 @@ UseGgmlGemm2:;
// ggml_compute_forward_mul_mat_id // ggml_compute_forward_mul_mat_id
static void ggml_compute_forward_mul_mat_id( #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; struct mmid_row_mapping {
const struct ggml_tensor * src1 = dst->src[1]; int32_t i1;
const struct ggml_tensor * ids = dst->src[2]; int32_t i2;
};
static void ggml_compute_forward_mul_mat_id_one_chunk(
struct ggml_tensor * dst,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
const struct ggml_tensor * ids,
const int64_t cur_a,
const int64_t ir0_start,
const int64_t ir0_end,
const int64_t ir1_start,
const int64_t ir1_end,
const char * src0_cur,
const struct mmid_row_mapping * matrix_rows,
const size_t row_size,
const bool src1_cont,
const void * wdata) {
GGML_TENSOR_BINARY_OP_LOCALS GGML_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_type type = src0->type; const enum ggml_type type = src0->type;
const bool src1_cont = ggml_is_contiguous(src1);
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// row groups
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
const size_t nbw2 = nbw1*ne11;
const size_t nbw3 = nbw2*ne12;
assert(params->wsize >= ne13*nbw3);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10);
}
}
}
}
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
if (ith == 0) {
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
// group rows by src0 matrix
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
for (int id = 0; id < n_ids; ++id) {
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
matrix_row_counts[i02] += 1;
}
}
}
ggml_barrier(params->threadpool);
// compute each matrix multiplication in sequence
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
const int64_t cne1 = matrix_row_counts[cur_a];
if (cne1 == 0) {
continue;
}
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows
// distribute the thread work across the inner or outer loop based on which one is larger
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
const int64_t ith0 = ith % nth0;
const int64_t ith1 = ith / nth0;
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
const int64_t ir010 = dr0*ith0;
const int64_t ir011 = MIN(ir010 + dr0, nr0);
const int64_t ir110 = dr1*ith1;
const int64_t ir111 = MIN(ir110 + dr1, nr1);
// threads with no work simply yield (not sure if it helps)
//if (ir010 >= ir011 || ir110 >= ir111) {
// sched_yield();
// continue;
//}
// block-tiling attempt
const int64_t blck_0 = 16; const int64_t blck_0 = 16;
const int64_t blck_1 = 16; const int64_t blck_1 = 16;
// attempt to reduce false-sharing (does not seem to make a difference)
float tmp[16]; float tmp[16];
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
const int64_t _i12 = ir1; // logical row index for this expert const int64_t _i12 = ir1; // logical row index for this expert
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
@ -7804,21 +7718,202 @@ static void ggml_compute_forward_mul_mat_id(
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
} }
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
} }
} }
} }
}
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
void * ptr = *p;
ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
*p = (void *) ((char *) ptr + size);
return ptr;
}
static void ggml_compute_forward_mul_mat_id(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * ids = dst->src[2];
GGML_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_type type = src0->type;
const bool src1_cont = ggml_is_contiguous(src1);
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// row groups
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
void * wdata_cur = params->wdata;
if (src1->type != vec_dot_type) {
incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
}
int64_t * matrix_row_counts = // [n_as]
incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t nbw0 = ggml_type_size(vec_dot_type);
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
const size_t nbw2 = nbw1*ne11;
const size_t nbw3 = nbw2*ne12;
assert(params->wsize >= ne13*nbw3);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
#if 0
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10);
}
}
}
#else
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
size_t bs = ggml_blck_size(vec_dot_type);
int64_t ne10_block_start = (ith * ne10/bs) / nth;
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
(ne10_block_end - ne10_block_start) * bs);
}
}
}
#endif
}
if (ith == 0) {
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
// group rows by src0 matrix
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
for (int id = 0; id < n_ids; ++id) {
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
matrix_row_counts[i02] += 1;
}
}
} }
#undef MMID_MATRIX_ROW // reset current_chunk
for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
*current_chunk_ctr = nth;
}
ggml_barrier(params->threadpool);
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
const int64_t cne1 = matrix_row_counts[cur_a];
if (cne1 == 0) {
continue;
}
const char * src0_cur = (const char *) src0->data + cur_a * nb02;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
const int64_t nr0 = ne01;
const int64_t nr1 = cne1;
int chunk_size = 16;
if (nr0 == 1 || nr1 == 1) {
chunk_size = 64;
}
#if defined(__aarch64__)
// disable for ARM
const bool disable_chunking = true;
#else
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
#endif // defined(__aarch64__)
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
nchunk0 = nr0 > nr1 ? nth : 1;
nchunk1 = nr0 > nr1 ? 1 : nth;
}
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
int current_chunk = ith;
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;
const int64_t ir0_start = dr0 * ith0;
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
const int64_t ir1_start = dr1 * ith1;
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
ggml_compute_forward_mul_mat_id_one_chunk(
dst, src0, src1, ids, cur_a,
ir0_start, ir0_end, ir1_start, ir1_end,
src0_cur, matrix_rows, row_size, src1_cont, wdata
);
if (nth >= nchunk0 * nchunk1) {
break;
}
current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
}
}
} }
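The rewritten loop above distributes each expert's rows with the same chunk-claiming scheme used for plain mat-mul: threads start on chunk `ith` and then claim further chunks through the per-expert atomic counter. A toy Python sketch of the chunk grid (no threads or atomics, just the index arithmetic; sizes are made up):

```python
# Toy sketch of the per-expert chunking above.
nr0, nr1, nth, chunk_size = 100, 9, 4, 16   # src0 rows, rows routed to this expert, threads, chunk size

nchunk0 = (nr0 + chunk_size - 1) // chunk_size
nchunk1 = (nr1 + chunk_size - 1) // chunk_size
if nchunk0 * nchunk1 < nth * 4:
    # too few chunks for dynamic scheduling: fall back to one slab per thread
    nchunk0, nchunk1 = (nth, 1) if nr0 > nr1 else (1, nth)

dr0 = (nr0 + nchunk0 - 1) // nchunk0
dr1 = (nr1 + nchunk1 - 1) // nchunk1

for chunk in range(nchunk0 * nchunk1):      # threads would claim these via fetch-add
    ith0, ith1 = chunk % nchunk0, chunk // nchunk0
    ir0 = (dr0 * ith0, min(dr0 * ith0 + dr0, nr0))
    ir1 = (dr1 * ith1, min(dr1 * ith1 + dr1, nr1))
    print(f"chunk {chunk}: src0 rows {ir0}, expert rows {ir1}")
```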
// ggml_compute_forward_out_prod // ggml_compute_forward_out_prod
@ -9112,10 +9207,6 @@ static void ggml_compute_forward_clamp_f32(
const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src0 = dst->src[0];
if (params->ith != 0) {
return;
}
float min; float min;
float max; float max;
memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
@ -13761,14 +13852,19 @@ struct ggml_cplan ggml_graph_plan(
cur = 0; cur = 0;
const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src0 = node->src[0];
const struct ggml_tensor * src1 = node->src[1]; const struct ggml_tensor * src1 = node->src[1];
const struct ggml_tensor * ids = node->src[2];
const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) {
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
}
const int n_as = src0->ne[2]; const int n_as = src0->ne[2];
cur += GGML_PAD(cur, sizeof(int64_t)); // align // src1
cur += n_as * sizeof(int64_t); // matrix_row_counts if (src1->type != vec_dot_type) {
cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
}
// matrix_row_counts
cur += n_as * sizeof(int64_t) + sizeof(int64_t);
// matrix_rows
cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
// atomic_current_chunk
cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
} break; } break;
case GGML_OP_OUT_PROD: case GGML_OP_OUT_PROD:
{ {

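The revised sizing above reserves, per `MUL_MAT_ID` node: the converted `src1` data (when conversion is needed), the per-expert row counts, the row mappings, and one cache line per expert for the new atomic chunk counters, each followed by alignment slack. A small Python restatement of the same arithmetic (constants are placeholders for `sizeof`/`CACHE_LINE_SIZE`):

```python
SIZEOF_INT64 = 8
SIZEOF_ROW_MAPPING = 8      # struct mmid_row_mapping: two int32
CACHE_LINE_SIZE = 64        # placeholder value

def mul_mat_id_wsize(needs_src1_conversion, src1_converted_bytes,
                     n_as, n_expert_used, n_tokens):
    cur = 0
    if needs_src1_conversion:
        cur += src1_converted_bytes + SIZEOF_INT64                              # src1 in vec_dot_type
    cur += n_as * SIZEOF_INT64 + SIZEOF_INT64                                   # matrix_row_counts
    cur += n_as * n_expert_used * n_tokens * SIZEOF_ROW_MAPPING + SIZEOF_INT64  # matrix_rows
    cur += CACHE_LINE_SIZE * n_as + CACHE_LINE_SIZE                             # atomic_current_chunk
    return cur

print(mul_mat_id_wsize(True, 1 << 20, n_as=8, n_expert_used=2, n_tokens=32))
```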
View file

@ -280,14 +280,6 @@ template <> inline __m256bh load(const float *p) {
} }
#endif #endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// CONSTANTS
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
#endif
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION // FLOATING POINT MATRIX MULTIPLICATION
@ -614,6 +606,14 @@ class tinyBLAS_Q0_AVX {
TC *C, int64_t ldc, TC *C, int64_t ldc,
int ith, int nth) int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
const int8_t kvalues_iq4nl[16] = {
-127, -104, -83, -65,
-49, -35, -22, -10,
1, 13, 25, 38,
53, 69, 89, 113
};
iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
} }
void matmul(int64_t m, int64_t n) { void matmul(int64_t m, int64_t n) {
@ -1038,6 +1038,7 @@ class tinyBLAS_Q0_AVX {
const int64_t ldc; const int64_t ldc;
const int ith; const int ith;
const int nth; const int nth;
__m128i iq4nlt;
}; };
#endif // __AVX__ #endif // __AVX__

View file

@ -180,11 +180,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
int major_version = 0; int major_version = 0;
size_t version_length = 0; size_t version_length = 0;
if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) { if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
std::string version(version_length, '\0'); std::vector<char> version(version_length+1, '\0');
if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) { if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
version.resize(::strlen(version.c_str())); version.resize(::strlen(version.data()));
int parsed_value = 0; int parsed_value = 0;
if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) { if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
major_version = parsed_value; major_version = parsed_value;
} }
} }
@ -1481,12 +1481,7 @@ static void ggml_cuda_op_mul_mat(
const size_t nbytes_data = ggml_nbytes(src0); const size_t nbytes_data = ggml_nbytes(src0);
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding); dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
// TODO: remove this for MUSA once the Guilty Lockup issue is resolved
#ifndef GGML_USE_MUSA
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream)); CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
#else // GGML_USE_MUSA
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
#endif // !GGML_USE_MUSA
} }
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:

View file

@ -151,5 +151,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
} }
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
} }

View file

@ -1434,6 +1434,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
// some shaders have a minimum subgroup size // some shaders have a minimum subgroup size
const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
@ -1496,13 +1497,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };

View file

@ -1174,6 +1174,9 @@ extern "C" {
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
/// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.

View file

@ -0,0 +1,22 @@
These templates can be updated with the following commands:
```bash
./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use > models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja
./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 default > models/templates/CohereForAI-c4ai-command-r7b-12-2024-default.jinja
./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 rag > models/templates/CohereForAI-c4ai-command-r7b-12-2024-rag.jinja
./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use > models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja
./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Llama-8B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja
./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Qwen-32B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja
./scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 > models/templates/fireworks-ai-llama-3-firefunction-v2.jinja
./scripts/get_chat_template.py google/gemma-2-2b-it > models/templates/google-gemma-2-2b-it.jinja
./scripts/get_chat_template.py meetkai/functionary-medium-v3.1 > models/templates/meetkai-functionary-medium-v3.1.jinja
./scripts/get_chat_template.py meetkai/functionary-medium-v3.2 > models/templates/meetkai-functionary-medium-v3.2.jinja
./scripts/get_chat_template.py meta-llama/Llama-3.1-8B-Instruct > models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja
./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct > models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja
./scripts/get_chat_template.py meta-llama/Llama-3.3-70B-Instruct > models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja
./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct > models/templates/microsoft-Phi-3.5-mini-instruct.jinja
./scripts/get_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 > models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja
./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
```

View file

@ -0,0 +1,76 @@
{%- if not add_generation_prompt is defined -%}
{%- set add_generation_prompt = false -%}
{%- endif -%}
{%- set ns = namespace(is_first=false, is_tool_outputs=false, is_output_first=true, system_prompt='') -%}
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
{%- set ns.system_prompt = message['content'] -%}
{%- endif -%}
{%- endfor -%}
{{bos_token}}
{%- if tools %}
You can call any of the following function tools to satisfy the user's requests: {{tools | map(attribute='function') | tojson(indent=2)}}
Example function tool call syntax:
<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>example_function_name
```json
{
"arg1": "some_value"
...
}
```
<｜tool▁call▁end｜><｜tool▁calls▁end｜>
{% endif -%}
{{ns.system_prompt}}
{%- macro flush_tool_outputs() -%}
{%- if ns.is_tool_outputs -%}
{{- '<｜tool▁outputs▁end｜><｜end▁of▁sentence｜>' -}}
{%- set ns.is_tool_outputs = false -%}
{%- endif -%}
{%- endmacro -%}
{{- flush_tool_outputs() -}}
{%- for message in messages -%}
{%- if message['role'] != 'tool' -%}
{{- flush_tool_outputs() -}}
{%- endif -%}
{%- if message['role'] == 'user' -%}
{{- '<｜User｜>' + message['content'] + '<｜end▁of▁sentence｜>' -}}
{%- endif -%}
{%- if message['role'] == 'assistant' and message['content'] is none -%}
{{- '<｜Assistant｜><｜tool▁calls▁begin｜>' -}}
{%- set ns.is_first = true -%}
{%- for tc in message['tool_calls'] -%}
{%- if ns.is_first -%}
{%- set ns.is_first = false -%}
{%- else -%}
{{- '\n' -}}
{%- endif -%}
{%- set tool_name = tc['function']['name'] -%}
{%- set tool_args = tc['function']['arguments'] -%}
{{- '<｜tool▁call▁begin｜>' + tc['type'] + '<｜tool▁sep｜>' + tool_name + '\n' + '```json' + '\n' + tool_args + '\n' + '```' + '<｜tool▁call▁end｜>' -}}
{%- endfor -%}
{{- '<｜tool▁calls▁end｜><｜end▁of▁sentence｜>' -}}
{%- endif -%}
{%- if message['role'] == 'assistant' and message['content'] is not none -%}
{{- flush_tool_outputs() -}}
{%- set content = message['content'] -%}
{%- if '</think>' in content -%}
{%- set content = content.split('</think>')[-1] -%}
{%- endif -%}
{{- '<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>' -}}
{%- endif -%}
{%- if message['role'] == 'tool' -%}
{%- set ns.is_tool_outputs = true -%}
{%- if ns.is_output_first -%}
{{- '<｜tool▁outputs▁begin｜>' -}}
{%- set ns.is_output_first = false -%}
{%- endif -%}
{{- '\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>' -}}
{%- endif -%}
{%- endfor -%}
{{- flush_tool_outputs() -}}
{%- if add_generation_prompt and not ns.is_tool_outputs -%}
{{- '<｜Assistant｜><think>\n' -}}
{%- endif -%}

View file

@ -1186,7 +1186,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
return; return;
} }
} }
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str()); LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
return; return;
} }
} }

View file

@ -6,13 +6,13 @@
#include <vector> #include <vector>
#ifdef __GNUC__ #ifdef __GNUC__
#ifdef __MINGW32__ # if defined(__MINGW32__) && !defined(__clang__)
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) # define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
# else
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
# endif
#else #else
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) # define LLAMA_ATTRIBUTE_FORMAT(...)
#endif
#else
#define LLAMA_ATTRIBUTE_FORMAT(...)
#endif #endif
// //

View file

@ -37,7 +37,7 @@ struct llama_kv_cache {
bool can_shift = false; bool can_shift = false;
// Note: The value of head isn't only used to optimize searching // Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it // for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated. // cannot be freely changed after a slot has been allocated.
uint32_t head = 0; uint32_t head = 0;
uint32_t size = 0; uint32_t size = 0;

View file

@ -1698,6 +1698,73 @@ struct llama_sampler * llama_sampler_init_penalties(
); );
} }
// top-n-sigma
struct llama_sampler_top_n_sigma {
const float n;
};
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
return "top-n-sigma";
}
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
// find max logit and calculate mean
float max = cur_p->data[0].logit;
float logits_sum = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit > max) {
max = cur_p->data[i].logit;
}
logits_sum += cur_p->data[i].logit;
}
float mean = logits_sum/cur_p->size;
// calculate standard deviation
float acc = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
acc += pow(cur_p->data[i].logit - mean, 2);
}
float std = sqrt(acc/cur_p->size);
//apply mask
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit < max - (ctx->n * std)) {
cur_p->data[i].logit = -INFINITY;
}
}
llama_sampler_softmax_impl(cur_p);
}
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
return llama_sampler_init_top_n_sigma(ctx->n);
}
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
delete (llama_sampler_top_n_sigma *) smpl->ctx;
}
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
/* .name = */ llama_sampler_top_n_sigma_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_n_sigma_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_n_sigma_clone,
/* .free = */ llama_sampler_top_n_sigma_free,
};
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_n_sigma_i,
/* .ctx = */ new llama_sampler_top_n_sigma {
/* .n = */ n,
}
);
}
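The sampler above reduces to one thresholding rule: compute the mean and (population) standard deviation of the logits, then mask every candidate whose logit falls more than `n` standard deviations below the maximum before re-normalizing. A NumPy restatement of the same math (a sketch, not the library API):

```python
import numpy as np

def top_n_sigma_mask(logits: np.ndarray, n: float) -> np.ndarray:
    # Mirrors llama_sampler_top_n_sigma_apply: logits below max - n*std go to -inf.
    mean = logits.mean()
    std = np.sqrt(((logits - mean) ** 2).mean())   # population std, as in the C++ above
    out = logits.copy()
    out[out < logits.max() - n * std] = -np.inf
    return out

logits = np.array([8.0, 7.5, 3.0, 1.0, -2.0])
print(top_n_sigma_mask(logits, n=1.0))   # keeps 8.0 and 7.5, masks the rest
```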
// DRY // DRY
struct llama_sampler_dry { struct llama_sampler_dry {