mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-06 08:01:27 +00:00
Merge commit 'c945aaaef2' into concedo_experimental
# Conflicts: # .devops/cann.Dockerfile # .github/workflows/build.yml # .github/workflows/release.yml # README.md # common/CMakeLists.txt # common/chat.cpp # docs/function-calling.md # ggml/src/ggml-cann/aclnn_ops.cpp # ggml/src/ggml-cann/aclnn_ops.h # ggml/src/ggml-cann/common.h # ggml/src/ggml-cann/ggml-cann.cpp # models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja # scripts/sync_vendor.py # tests/CMakeLists.txt # tests/peg-parser/tests.h # tests/test-chat-peg-parser.cpp # tests/test-chat-template.cpp # tests/test-chat.cpp # tests/testing.h # tools/llama-bench/llama-bench.cpp
This commit is contained in:
commit
8855a7f52b
23 changed files with 6659 additions and 3690 deletions
276
common/chat.cpp
276
common/chat.cpp
|
|
@ -6,11 +6,24 @@
|
|||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "json-partial.cpp"
|
||||
#include "minja/chat-template.hpp"
|
||||
#include "minja/minja.hpp"
|
||||
#include "regex-partial.cpp"
|
||||
#include "chat-parser-xml-toolcall.cpp"
|
||||
|
||||
// #include <minja/chat-template.hpp>
|
||||
// #include <minja/minja.hpp>
|
||||
|
||||
#include "jinja/parser.h"
|
||||
#include "jinja/value.h"
|
||||
#include "jinja/runtime.h"
|
||||
#include "jinja/caps.h"
|
||||
|
||||
#include "jinja/lexer.cpp"
|
||||
#include "jinja/parser.cpp"
|
||||
#include "jinja/runtime.cpp"
|
||||
#include "jinja/value.cpp"
|
||||
#include "jinja/string.cpp"
|
||||
#include "jinja/caps.cpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cctype>
|
||||
|
|
@ -136,7 +149,68 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
|
|||
return diffs;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
using chat_template_caps = jinja::caps;
|
||||
|
||||
struct common_chat_template {
|
||||
jinja::program prog;
|
||||
std::string bos_tok;
|
||||
std::string eos_tok;
|
||||
std::string src;
|
||||
chat_template_caps caps;
|
||||
|
||||
common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) {
|
||||
jinja::lexer lexer;
|
||||
auto lexer_res = lexer.tokenize(src);
|
||||
this->prog = jinja::parse_from_tokens(lexer_res);
|
||||
|
||||
this->src = lexer_res.source;
|
||||
this->bos_tok = bos_token;
|
||||
this->eos_tok = eos_token;
|
||||
|
||||
this->caps = jinja::caps_get(prog);
|
||||
// LOG_INF("%s: caps:\n%s\n", __func__, this->caps.to_string().c_str());
|
||||
}
|
||||
|
||||
const std::string & source() const { return src; }
|
||||
const std::string & bos_token() const { return bos_tok; }
|
||||
const std::string & eos_token() const { return eos_tok; }
|
||||
|
||||
// TODO: this is ugly, refactor it somehow
|
||||
json add_system(const json & messages, const std::string & system_prompt) const {
|
||||
GGML_ASSERT(messages.is_array());
|
||||
auto msgs_copy = messages;
|
||||
if (!caps.supports_system_role) {
|
||||
if (msgs_copy.empty()) {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "user"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else {
|
||||
auto & first_msg = msgs_copy[0];
|
||||
if (!first_msg.contains("content")) {
|
||||
first_msg["content"] = "";
|
||||
}
|
||||
first_msg["content"] = system_prompt + "\n\n"
|
||||
+ first_msg["content"].get<std::string>();
|
||||
}
|
||||
} else {
|
||||
if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "system"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else if (msgs_copy[0].at("role") == "system") {
|
||||
msgs_copy[0]["content"] = system_prompt;
|
||||
}
|
||||
}
|
||||
return msgs_copy;
|
||||
}
|
||||
|
||||
chat_template_caps original_caps() const {
|
||||
return caps;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct common_chat_templates {
|
||||
bool add_bos;
|
||||
|
|
@ -162,6 +236,7 @@ struct templates_params {
|
|||
bool add_bos;
|
||||
bool add_eos;
|
||||
bool is_inference = true;
|
||||
bool mark_input = true; // whether to mark input strings in the jinja context
|
||||
};
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
||||
|
|
@ -633,14 +708,16 @@ common_chat_templates_ptr common_chat_templates_init(
|
|||
tmpls->add_bos = add_bos;
|
||||
tmpls->add_eos = add_eos;
|
||||
try {
|
||||
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
|
||||
tmpls->template_default = std::make_unique<common_chat_template>(default_template_src, token_bos, token_eos);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
|
||||
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
|
||||
LOG_ERR("%s: error: %s\n", __func__, e.what());
|
||||
LOG_ERR("%s: failed to initialize chat template\n", __func__);
|
||||
LOG_ERR("%s: please consider disabling jinja via --no-jinja, or using another chat template\n", __func__);
|
||||
throw e;
|
||||
}
|
||||
if (!template_tool_use_src.empty()) {
|
||||
try {
|
||||
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
|
||||
tmpls->template_tool_use = std::make_unique<common_chat_template>(template_tool_use_src, token_bos, token_eos);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
|
||||
}
|
||||
|
|
@ -745,27 +822,43 @@ static std::string apply(
|
|||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt)
|
||||
{
|
||||
minja::chat_template_inputs tmpl_inputs;
|
||||
tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
|
||||
if (tools_override) {
|
||||
tmpl_inputs.tools = *tools_override;
|
||||
} else {
|
||||
tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
|
||||
}
|
||||
tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
|
||||
tmpl_inputs.extra_context = inputs.extra_context;
|
||||
tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking;
|
||||
if (additional_context) {
|
||||
tmpl_inputs.extra_context.merge_patch(*additional_context);
|
||||
}
|
||||
// TODO: add flag to control date/time, if only for testing purposes.
|
||||
// tmpl_inputs.now = std::chrono::system_clock::now();
|
||||
jinja::context ctx(tmpl.source());
|
||||
|
||||
minja::chat_template_options tmpl_opts;
|
||||
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
|
||||
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
|
||||
// may be needed inside the template / between messages too.
|
||||
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
|
||||
nlohmann::ordered_json inp = nlohmann::ordered_json{
|
||||
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
|
||||
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
|
||||
{"bos_token", tmpl.bos_token()},
|
||||
{"eos_token", tmpl.eos_token()},
|
||||
};
|
||||
if (inputs.extra_context.is_object()) {
|
||||
// TODO: do we need to merge, or replacing is fine?
|
||||
for (const auto & [k, v] : inputs.extra_context.items()) {
|
||||
inp[k] = v;
|
||||
}
|
||||
}
|
||||
if (additional_context.has_value()) {
|
||||
// TODO: merge properly instead of overwriting (matching old behavior)
|
||||
for (const auto & [k, v] : additional_context->items()) {
|
||||
inp[k] = v;
|
||||
}
|
||||
}
|
||||
if (inputs.add_generation_prompt) {
|
||||
inp["add_generation_prompt"] = true;
|
||||
}
|
||||
if (inp["tools"].is_null()) {
|
||||
inp["tools"] = json::array();
|
||||
}
|
||||
|
||||
jinja::global_from_json(ctx, inp, inputs.mark_input);
|
||||
|
||||
// render
|
||||
jinja::runtime runtime(ctx);
|
||||
const jinja::value results = runtime.execute(tmpl.prog);
|
||||
auto parts = runtime.gather_string_parts(results);
|
||||
|
||||
std::string result = parts->as_string().str();
|
||||
|
||||
// TODO: improve this later
|
||||
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
|
||||
result = result.substr(tmpl.bos_token().size());
|
||||
}
|
||||
|
|
@ -852,10 +945,17 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
|||
builder.add_schema("root", schema);
|
||||
});
|
||||
|
||||
auto tweaked_messages = common_chat_template::add_system(
|
||||
auto tweaked_messages = tmpl.add_system(
|
||||
inputs.messages,
|
||||
"Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
|
||||
|
||||
// ensure all messages has "content" field
|
||||
for (auto & message : tweaked_messages) {
|
||||
if (!message.contains("content") || message["content"].is_null()) {
|
||||
message["content"] = "";
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
return data;
|
||||
|
|
@ -1370,7 +1470,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
|
|||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json {
|
||||
{"date_string", format_time(inputs.now, "%d %b %Y")},
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
{"builtin_tools", builtin_tools},
|
||||
});
|
||||
return data;
|
||||
}
|
||||
|
|
@ -2675,6 +2775,107 @@ static common_chat_params common_chat_params_init_seed_oss(
|
|||
return data;
|
||||
}
|
||||
|
||||
// various workarounds for known issues with certain templates or model behaviors
|
||||
// TODO @ngxson : improve this (how?)
|
||||
namespace workaround {
|
||||
|
||||
// if first message is system and template does not support it, merge it with next message
|
||||
static void system_message_not_supported(json & messages) {
|
||||
if (!messages.empty() && messages.front().at("role") == "system") {
|
||||
if (messages.size() > 1) {
|
||||
LOG_DBG("Merging system prompt into next message\n");
|
||||
auto & first_msg = messages.front();
|
||||
auto & second_msg = messages[1];
|
||||
second_msg["content"] = first_msg.at("content").get<std::string>()
|
||||
+ "\n" + second_msg.at("content").get<std::string>();
|
||||
messages.erase(messages.begin());
|
||||
} else {
|
||||
LOG_WRN("Removing system prompt due to template not supporting system role\n");
|
||||
messages.erase(messages.begin());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void func_args_not_string(json & messages) {
|
||||
GGML_ASSERT(messages.is_array());
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("tool_calls")) {
|
||||
for (auto & tool_call : message["tool_calls"]) {
|
||||
if (tool_call.contains("function") && tool_call["function"].contains("arguments")) {
|
||||
auto & args = tool_call["function"]["arguments"];
|
||||
if (args.is_string()) {
|
||||
try {
|
||||
args = json::parse(args.get<std::string>());
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error("Failed to parse tool call arguments as JSON: " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void move_tool_calls_to_content(json & messages, int indent_spaces = 2) {
|
||||
GGML_ASSERT(messages.is_array());
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("tool_calls")) {
|
||||
auto tool_calls_new = json{
|
||||
{"tool_calls", message.at("tool_calls")}
|
||||
};
|
||||
message.erase("tool_calls");
|
||||
auto content = message.at("content");
|
||||
std::string content_new = content.is_null() ? "" : content.get<std::string>();
|
||||
message["content"] = content_new + tool_calls_new.dump(indent_spaces, ' ', false, json::error_handler_t::replace);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO @ngxson : we may remove support for generic schema in the future
|
||||
static void use_generic_schema(json & messages) {
|
||||
GGML_ASSERT(messages.is_array());
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("tool_calls") && message.at("tool_calls").is_array()) {
|
||||
auto & tool_calls = message.at("tool_calls");
|
||||
for (auto & tool_call : tool_calls) {
|
||||
if (tool_call.contains("type") && tool_call.at("type") == "function" &&
|
||||
tool_call.contains("function") && tool_call.at("function").is_object()) {
|
||||
// Copy values before erasing to avoid use-after-free
|
||||
json name_value;
|
||||
json arguments_value;
|
||||
json id_value;
|
||||
const auto & function = tool_call.at("function");
|
||||
if (function.contains("name")) {
|
||||
name_value = function.at("name");
|
||||
}
|
||||
if (function.contains("arguments")) {
|
||||
arguments_value = function.at("arguments");
|
||||
}
|
||||
if (tool_call.contains("id")) {
|
||||
id_value = tool_call.at("id");
|
||||
}
|
||||
// Now safely erase and assign in the correct order
|
||||
tool_call.erase("type");
|
||||
tool_call.erase("function");
|
||||
tool_call.erase("id");
|
||||
// Reassign in desired order: name, arguments, id
|
||||
if (!name_value.is_null()) {
|
||||
tool_call["name"] = name_value;
|
||||
}
|
||||
if (!arguments_value.is_null()) {
|
||||
tool_call["arguments"] = arguments_value;
|
||||
}
|
||||
if (!id_value.is_null()) {
|
||||
tool_call["id"] = id_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace workaround
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
|
|
@ -2696,6 +2897,10 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
workaround::system_message_not_supported(params.messages);
|
||||
}
|
||||
|
||||
params.extra_context = json::object();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
params.extra_context[el.first] = json::parse(el.second);
|
||||
|
|
@ -2734,11 +2939,15 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
|
||||
// Command R7B: : use handler in all cases except json schema (thinking / tools).
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
return common_chat_params_init_command_r7b(tmpl, params);
|
||||
}
|
||||
|
||||
// Granite (IBM) - detects thinking / tools support
|
||||
if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
workaround::use_generic_schema(params.messages);
|
||||
workaround::move_tool_calls_to_content(params.messages);
|
||||
return common_chat_params_init_granite(tmpl, params);
|
||||
}
|
||||
|
||||
|
|
@ -2747,6 +2956,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
src.find("<arg_key>") != std::string::npos &&
|
||||
src.find("<arg_value>") != std::string::npos &&
|
||||
params.json_schema.is_null()) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
return common_chat_params_init_glm_4_5(tmpl, params);
|
||||
}
|
||||
|
||||
|
|
@ -2758,6 +2968,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
src.find("<function=") != std::string::npos &&
|
||||
src.find("<parameters>") != std::string::npos &&
|
||||
src.find("<parameter=") != std::string::npos) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
// Nemotron 3 Nano 30B A3B
|
||||
if (src.find("<think>") != std::string::npos) {
|
||||
return common_chat_params_init_nemotron_v3(tmpl, params);
|
||||
|
|
@ -2794,6 +3005,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
|
||||
// Seed-OSS
|
||||
if (src.find("<seed:think>") != std::string::npos) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
return common_chat_params_init_seed_oss(tmpl, params, inputs);
|
||||
}
|
||||
|
||||
|
|
@ -2815,6 +3027,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
|
||||
// MiniMax-M2 format detection
|
||||
if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
return common_chat_params_init_minimax_m2(tmpl, params);
|
||||
}
|
||||
|
||||
|
|
@ -2861,6 +3074,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
workaround::func_args_not_string(params.messages);
|
||||
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
|
||||
}
|
||||
|
||||
|
|
@ -2889,10 +3103,14 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
|
||||
// Mistral Nemo (w/ tools)
|
||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
return common_chat_params_init_mistral_nemo(tmpl, params);
|
||||
}
|
||||
|
||||
// Generic fallback
|
||||
workaround::func_args_not_string(params.messages);
|
||||
workaround::use_generic_schema(params.messages);
|
||||
workaround::move_tool_calls_to_content(params.messages);
|
||||
return common_chat_params_init_generic(tmpl, params);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue