mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 11:40:43 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .github/workflows/ai-issues.yml # CONTRIBUTING.md # docs/autoparser.md # docs/ops.md # docs/ops/Metal.csv # ggml/src/ggml-cann/aclnn_ops.cpp # ggml/src/ggml-cann/ggml-cann.cpp # ggml/src/ggml-cpu/CMakeLists.txt # ggml/src/ggml-hexagon/ggml-hexagon.cpp # ggml/src/ggml-hexagon/htp/CMakeLists.txt # ggml/src/ggml-hexagon/htp/hex-dma.h # ggml/src/ggml-hexagon/htp/hex-utils.h # ggml/src/ggml-hexagon/htp/htp-ctx.h # ggml/src/ggml-hexagon/htp/htp-msg.h # ggml/src/ggml-hexagon/htp/htp_iface.idl # ggml/src/ggml-hexagon/htp/hvx-base.h # ggml/src/ggml-hexagon/htp/main.c # ggml/src/ggml-hip/CMakeLists.txt # models/templates/Apriel-1.6-15b-Thinker-fixed.jinja # models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja # models/templates/deepseek-ai-DeepSeek-V3.1.jinja # models/templates/llama-cpp-deepseek-r1.jinja # models/templates/meetkai-functionary-medium-v3.1.jinja # scripts/fetch_server_test_models.py # scripts/snapdragon/adb/run-cli.sh # scripts/snapdragon/adb/run-completion.sh # scripts/snapdragon/adb/run-mtmd.sh # scripts/snapdragon/adb/run-tool.sh # tests/test-chat-auto-parser.cpp # tests/test-chat-peg-parser.cpp # tests/test-chat.cpp # tools/cli/cli.cpp # tools/server/README.md
This commit is contained in:
commit
6054bacadd
33 changed files with 834 additions and 491 deletions
|
|
@ -1833,23 +1833,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--grammar"}, "GRAMMAR",
|
||||
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
|
||||
"BNF-like grammar to constrain generations (see samples in grammars/ dir)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.grammar = value;
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, value};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--grammar-file"}, "FNAME",
|
||||
"file to read grammar from",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.grammar = read_file(value);
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, read_file(value)};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"-j", "--json-schema"}, "SCHEMA",
|
||||
"JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.grammar = json_schema_to_grammar(json::parse(value));
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(value))};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
|
|
@ -1866,7 +1866,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
std::istreambuf_iterator<char>(),
|
||||
std::back_inserter(schema)
|
||||
);
|
||||
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(schema))};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
|
|
@ -23,13 +24,13 @@ static void foreach_function(const json & tools, const std::function<void(const
|
|||
|
||||
namespace autoparser {
|
||||
|
||||
parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) :
|
||||
parser_build_context::parser_build_context(common_chat_peg_builder & p, const generation_params & inputs) :
|
||||
p(p),
|
||||
inputs(inputs),
|
||||
reasoning_parser(p.eps()) {}
|
||||
|
||||
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs) {
|
||||
const struct generation_params & inputs) {
|
||||
// Run differential analysis to extract template structure
|
||||
struct autoparser autoparser;
|
||||
autoparser.analyze_template(tmpl);
|
||||
|
|
@ -37,17 +38,16 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
|||
}
|
||||
|
||||
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs,
|
||||
const struct generation_params & inputs,
|
||||
const autoparser & autoparser) {
|
||||
// Build the parser using the analysis results
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
data.parser = parser.save();
|
||||
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
data.parser = parser.save();
|
||||
|
||||
// Build grammar if tools are present
|
||||
bool has_tools =
|
||||
|
|
@ -82,44 +82,38 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
|||
return data;
|
||||
}
|
||||
|
||||
common_peg_arena autoparser::build_parser(const templates_params & inputs) const {
|
||||
common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
|
||||
if (!analysis_complete) {
|
||||
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
|
||||
}
|
||||
return build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// If the template uses Python dict format (single-quoted strings in JSON structures),
|
||||
// pre-register a json-string rule that accepts both quote styles. This must happen
|
||||
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
|
||||
if (tools.format.uses_python_dicts) {
|
||||
p.rule("json-string", p.quoted_string());
|
||||
}
|
||||
|
||||
parser_build_context ctx(p, inputs);
|
||||
bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
bool enable_thinking = inputs.enable_thinking;
|
||||
|
||||
ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE;
|
||||
ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
|
||||
ctx.content = &content;
|
||||
|
||||
// Build reasoning parser
|
||||
ctx.reasoning_parser = reasoning.build_parser(ctx);
|
||||
|
||||
auto parser = p.eps();
|
||||
|
||||
bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
return ctx.reasoning_parser + p.space() + p.choice({
|
||||
parser = ctx.reasoning_parser + p.space() + p.choice({
|
||||
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
|
||||
response_format
|
||||
}) + p.end();
|
||||
} else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
parser = tools.build_parser(ctx);
|
||||
} else {
|
||||
parser = content.build_parser(ctx);
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
return tools.build_parser(ctx);
|
||||
}
|
||||
|
||||
return content.build_parser(ctx);
|
||||
parser = wrap_for_generation_prompt(p, parser, inputs, reasoning.start);
|
||||
return parser;
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -130,24 +124,15 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
|
|||
return p.eps();
|
||||
}
|
||||
|
||||
bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN);
|
||||
bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED);
|
||||
|
||||
if (thinking_forced_open || thinking_forced_closed) {
|
||||
// Thinking is forced open OR forced closed with enable_thinking=true
|
||||
// In both cases, expect only the closing tag (opening was in template)
|
||||
// However, since we might have incorrectly detected the open/close pattern,
|
||||
// we admit an optional starting marker
|
||||
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
|
||||
}
|
||||
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
|
||||
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
|
||||
// Both use the same tag-based pattern if markers are available
|
||||
if (!start.empty() && !end.empty()) {
|
||||
return p.optional(start + p.reasoning(p.until(end)) + end);
|
||||
if (!end.empty()) {
|
||||
if (!start.empty()) {
|
||||
// Standard tag-based: optional(<think>reasoning</think>)
|
||||
return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
|
||||
}
|
||||
// Delimiter-style (empty start)
|
||||
return p.optional(p.reasoning(p.until(end)) + end + p.space());
|
||||
}
|
||||
} else if (mode == reasoning_mode::DELIMITER) {
|
||||
return p.optional(p.reasoning(p.until(end)) + end);
|
||||
}
|
||||
|
||||
return p.eps();
|
||||
|
|
@ -335,7 +320,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
|||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, true)) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) +
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.space()) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
|
||||
|
|
@ -384,7 +369,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
|||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section) + p.space() + args_seq;
|
||||
matched_atomic = true;
|
||||
} else if (!arguments.name_prefix.empty() && properties.size() > 0) {
|
||||
} else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
|
||||
// Only peek for an arg tag when there are required args that must follow.
|
||||
// When all args are optional, the model may emit no arg tags at all (#20650).
|
||||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
|
||||
matched_atomic = true;
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
#include "chat-auto-parser-helpers.h"
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <cctype>
|
||||
#include <numeric>
|
||||
|
|
@ -186,6 +188,21 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri
|
|||
result.suffix = "";
|
||||
// pick prefix = all as representation
|
||||
}
|
||||
|
||||
// When left has no unique content (result.left is empty), left is entirely
|
||||
// shared with right. The simultaneous prefix/suffix segment matching can
|
||||
// incorrectly consume trailing segments of left as suffix when those same
|
||||
// segments also appear at the end of right (e.g. "\n" at the end of both
|
||||
// the shared content and the generation prompt). This rotates the diff.
|
||||
// Fix: if left is a prefix of right, enforce that directly.
|
||||
if (result.left.empty() && !result.right.empty() &&
|
||||
left.size() <= right.size() &&
|
||||
right.substr(0, left.size()) == left) {
|
||||
result.prefix = left;
|
||||
result.suffix = "";
|
||||
result.right = right.substr(left.size());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
@ -291,10 +308,26 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm
|
|||
return result;
|
||||
}
|
||||
|
||||
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
|
||||
const common_peg_parser & prs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::string & reasoning_start) {
|
||||
auto parser = prs;
|
||||
if (!inputs.generation_prompt.empty()) {
|
||||
size_t end_pos = inputs.generation_prompt.size();
|
||||
if (!reasoning_start.empty() && inputs.generation_prompt.find(reasoning_start) != std::string::npos) {
|
||||
end_pos = inputs.generation_prompt.find(reasoning_start);
|
||||
}
|
||||
std::string cut_genprompt = inputs.generation_prompt.substr(0, end_pos);
|
||||
parser = p.literal(cut_genprompt) + parser;
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
|
||||
namespace autoparser {
|
||||
|
||||
std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
|
||||
templates_params tmpl_params;
|
||||
generation_params tmpl_params;
|
||||
tmpl_params.messages = params.messages;
|
||||
tmpl_params.tools = params.tools;
|
||||
tmpl_params.add_generation_prompt = params.add_generation_prompt;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "peg-parser.h"
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
|
@ -57,6 +58,11 @@ std::vector<segment> segmentize_markers(const std::string & text);
|
|||
// (MARKER, "</function>"), (MARKER, "</tool_call>") ]
|
||||
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments);
|
||||
|
||||
// Wrap parser with generation prompt parser
|
||||
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
|
||||
const common_peg_parser & prs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::string & reasoning_start = {});
|
||||
namespace autoparser {
|
||||
|
||||
// Apply a template with the given parameters, returning the rendered string (empty on failure)
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ namespace autoparser {
|
|||
// High-level params for parser generation
|
||||
// ============================================================================
|
||||
|
||||
struct templates_params {
|
||||
struct generation_params {
|
||||
json messages;
|
||||
json tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
|
|
@ -62,6 +62,7 @@ struct templates_params {
|
|||
bool add_generation_prompt = false;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::string generation_prompt;
|
||||
json extra_context;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
|
|
@ -77,11 +78,7 @@ struct templates_params {
|
|||
// Reasoning handling mode (derived from R1-R3 comparisons)
|
||||
enum class reasoning_mode {
|
||||
NONE, // No reasoning markers detected
|
||||
TAG_BASED, // Standard tag-based: <think>...</think>
|
||||
DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter)
|
||||
FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end)
|
||||
FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but
|
||||
// with both opened and closed tag for disabled thinking
|
||||
TAG_BASED, // Tag-based: <think>...</think> (start can be empty for delimiter-style)
|
||||
TOOLS_ONLY // Only reason on tool calls, not on normal content
|
||||
};
|
||||
|
||||
|
|
@ -91,12 +88,6 @@ inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode)
|
|||
return os << "NONE";
|
||||
case reasoning_mode::TAG_BASED:
|
||||
return os << "TAG_BASED";
|
||||
case reasoning_mode::DELIMITER:
|
||||
return os << "DELIMITER";
|
||||
case reasoning_mode::FORCED_OPEN:
|
||||
return os << "FORCED_OPEN";
|
||||
case reasoning_mode::FORCED_CLOSED:
|
||||
return os << "FORCED_CLOSED";
|
||||
case reasoning_mode::TOOLS_ONLY:
|
||||
return os << "TOOLS_ONLY";
|
||||
default:
|
||||
|
|
@ -184,7 +175,6 @@ struct tool_format_analysis {
|
|||
|
||||
bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } }
|
||||
bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...]
|
||||
bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings)
|
||||
|
||||
std::string function_field = "function";
|
||||
std::string name_field = "name";
|
||||
|
|
@ -225,12 +215,12 @@ struct analyze_content;
|
|||
|
||||
struct parser_build_context {
|
||||
common_chat_peg_builder & p;
|
||||
const templates_params & inputs;
|
||||
const generation_params & inputs;
|
||||
common_peg_parser reasoning_parser;
|
||||
bool extracting_reasoning = false;
|
||||
const analyze_content * content = nullptr;
|
||||
|
||||
parser_build_context(common_chat_peg_builder & p, const templates_params & inputs);
|
||||
parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -260,6 +250,7 @@ struct analyze_reasoning : analyze_base {
|
|||
|
||||
analyze_reasoning() = default;
|
||||
analyze_reasoning(const common_chat_template & tmpl, bool supports_tools);
|
||||
analyze_reasoning(std::string start_, std::string end_) : start(std::move(start_)), end(std::move(end_)) {}
|
||||
|
||||
common_peg_parser build_parser(parser_build_context & ctx) const override;
|
||||
|
||||
|
|
@ -381,7 +372,7 @@ struct autoparser {
|
|||
void analyze_template(const common_chat_template & tmpl);
|
||||
|
||||
// Build the PEG parser for this template
|
||||
common_peg_arena build_parser(const templates_params & inputs) const;
|
||||
common_peg_arena build_parser(const generation_params & inputs) const;
|
||||
|
||||
private:
|
||||
// Collect tokens from entire analysis to preserve
|
||||
|
|
@ -395,10 +386,10 @@ struct autoparser {
|
|||
class peg_generator {
|
||||
public:
|
||||
static common_chat_params generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs);
|
||||
const struct generation_params & inputs);
|
||||
|
||||
static common_chat_params generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs,
|
||||
const struct generation_params & inputs,
|
||||
const autoparser & autoparser);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "peg-parser.h"
|
||||
|
|
@ -31,8 +32,9 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
|||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find("content.split('</think>')") != std::string::npos &&
|
||||
tmpl.src.find("reasoning_content") == std::string::npos &&
|
||||
tmpl.src.find("<SPECIAL_12>") == std::string::npos &&
|
||||
analysis.reasoning.mode == reasoning_mode::NONE) {
|
||||
analysis.reasoning.mode = reasoning_mode::FORCED_OPEN;
|
||||
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
|
||||
analysis.reasoning.start = "<think>";
|
||||
analysis.reasoning.end = "</think>";
|
||||
analysis.preserved_tokens.push_back("<think>");
|
||||
|
|
@ -185,7 +187,6 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
|
|||
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
|
||||
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
|
||||
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
|
||||
LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? "true" : "false");
|
||||
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
|
||||
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
|
||||
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
|
||||
|
|
@ -295,16 +296,12 @@ void analyze_reasoning::compare_reasoning_presence() {
|
|||
}
|
||||
if (result.result.success()) {
|
||||
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
|
||||
if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
} else {
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
start = trim_whitespace(result.tags["pre"]);
|
||||
end = result.tags["post"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else if (!result.tags["post"].empty()) {
|
||||
mode = reasoning_mode::DELIMITER;
|
||||
end = result.tags["post"];
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -331,53 +328,30 @@ void analyze_reasoning::compare_thinking_enabled() {
|
|||
const auto & diff = comparison->diff;
|
||||
|
||||
std::string left_trimmed = trim_whitespace(diff.left);
|
||||
std::string right_trimmed = trim_whitespace(diff.right);
|
||||
|
||||
if (left_trimmed.empty() && !diff.right.empty()) {
|
||||
std::string right_trimmed = trim_whitespace(diff.right);
|
||||
|
||||
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
|
||||
if (start.empty()) {
|
||||
start = right_trimmed;
|
||||
mode = reasoning_mode::FORCED_OPEN;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
} else if (right_trimmed.empty() && !diff.left.empty()) {
|
||||
if (!left_trimmed.empty() && string_ends_with(comparison->output_A, left_trimmed)) {
|
||||
if (end.empty()) {
|
||||
auto seg = prune_whitespace_segments(segmentize_markers(comparison->output_A));
|
||||
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
|
||||
start = seg[seg.size() - 2].value;
|
||||
}
|
||||
end = left_trimmed;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (start.empty() && !end.empty()) {
|
||||
mode = reasoning_mode::DELIMITER;
|
||||
}
|
||||
|
||||
// Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers,
|
||||
// but enable_thinking=true produces only the start marker
|
||||
if (!comparison->output_A.empty() && !comparison->output_B.empty()) {
|
||||
auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.literal(start) + p.space() + p.literal(end) + p.rest();
|
||||
});
|
||||
auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest();
|
||||
});
|
||||
if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() &&
|
||||
parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) {
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
} else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier
|
||||
auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A);
|
||||
if (result.result.success()) {
|
||||
start = result.tags["pre"];
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close"
|
||||
if (!diff.left.empty() && !diff.right.empty()) {
|
||||
auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left));
|
||||
auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right));
|
||||
if (seg_A.size() == 1 && seg_B.size() == 1) {
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
start = seg_B[0].value;
|
||||
end = seg_A[0].value;
|
||||
}
|
||||
}
|
||||
if (mode == reasoning_mode::NONE && start.empty() && !end.empty()) {
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -426,16 +400,16 @@ void analyze_reasoning::compare_reasoning_scope() {
|
|||
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
start = result.tags["pre"];
|
||||
end = result.tags["post"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else {
|
||||
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
|
||||
});
|
||||
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
end = result.tags["post"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
|
||||
LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
|
||||
mode = reasoning_mode::NONE;
|
||||
}
|
||||
}
|
||||
|
|
@ -600,33 +574,23 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
|
|||
return;
|
||||
}
|
||||
|
||||
enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES };
|
||||
|
||||
auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style {
|
||||
auto in_json_haystack = [&haystack](const std::string & needle) -> bool {
|
||||
auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({
|
||||
p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")),
|
||||
p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) });
|
||||
});
|
||||
auto result = parser.parse_anywhere_and_extract(haystack);
|
||||
if (!result.result.success()) {
|
||||
return json_quote_style::NONE;
|
||||
}
|
||||
return result.tags.count("sq") && !result.tags["sq"].empty()
|
||||
? json_quote_style::SINGLE_QUOTES
|
||||
: json_quote_style::DOUBLE_QUOTES;
|
||||
return result.result.success();
|
||||
};
|
||||
|
||||
auto fun_quote = in_json_haystack(fun_name_needle);
|
||||
auto arg_quote = in_json_haystack(arg_name_needle);
|
||||
|
||||
if (fun_quote != json_quote_style::NONE) {
|
||||
if (fun_quote) {
|
||||
// no need to check further, we're in JSON land
|
||||
format.mode = tool_format::JSON_NATIVE;
|
||||
format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES);
|
||||
} else if (arg_quote != json_quote_style::NONE) {
|
||||
} else if (arg_quote) {
|
||||
format.mode = tool_format::TAG_WITH_JSON;
|
||||
format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES);
|
||||
} else {
|
||||
format.mode = tool_format::TAG_WITH_TAGGED;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -229,6 +229,20 @@ void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
|
|||
result.tool_calls.push_back(pending_tool_call.value());
|
||||
pending_tool_call.reset();
|
||||
}
|
||||
|
||||
// Discard whitespace-only reasoning content (e.g. from <think></think> prefill)
|
||||
if (!result.reasoning_content.empty()) {
|
||||
bool all_whitespace = true;
|
||||
for (char c : result.reasoning_content) {
|
||||
if (c != ' ' && c != '\n' && c != '\r' && c != '\t') {
|
||||
all_whitespace = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (all_whitespace) {
|
||||
result.reasoning_content.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||
|
|
|
|||
318
common/chat.cpp
318
common/chat.cpp
|
|
@ -1,5 +1,6 @@
|
|||
#include "chat.h"
|
||||
|
||||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "common.h"
|
||||
|
|
@ -33,6 +34,7 @@
|
|||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
|
@ -775,7 +777,7 @@ static void foreach_parameter(const json &
|
|||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override,
|
||||
const std::optional<json> & tools_override,
|
||||
const std::optional<json> & additional_context) {
|
||||
|
|
@ -826,7 +828,7 @@ std::string common_chat_template_direct_apply(
|
|||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
|
||||
|
|
@ -891,8 +893,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
// Response format parser
|
||||
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
|
||||
// Ministral wants to emit json surrounded by code fences
|
||||
return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema))
|
||||
<< "```";
|
||||
return wrap_for_generation_prompt(p, reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```",
|
||||
inputs, "[THINK]");
|
||||
}
|
||||
|
||||
// Tool call parser
|
||||
|
|
@ -912,12 +914,13 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
|
||||
|
||||
return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls;
|
||||
return wrap_for_generation_prompt(p, reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls,
|
||||
inputs, "[THINK]");
|
||||
}
|
||||
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
return reasoning << p.content(p.rest());
|
||||
return wrap_for_generation_prompt(p, reasoning << p.content(p.rest()), inputs, "[THINK]");
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -943,7 +946,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
|
||||
|
|
@ -1003,7 +1006,8 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
|
||||
return response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format);
|
||||
return wrap_for_generation_prompt(p, response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format),
|
||||
inputs, "<|channel|>");
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
|
|
@ -1035,10 +1039,12 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call);
|
||||
}
|
||||
|
||||
return tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg));
|
||||
return wrap_for_generation_prompt(p, tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)),
|
||||
inputs, "<|channel|>");
|
||||
}
|
||||
|
||||
return final_msg | (any + p.zero_or_more(start + any) + start + final_msg);
|
||||
return wrap_for_generation_prompt(p, final_msg | (any + p.zero_or_more(start + any) + start + final_msg),
|
||||
inputs, "<|channel|>");
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1066,7 +1072,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
|
||||
// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content}
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
|
|
@ -1087,13 +1093,13 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
// Build content parser for >>>all\n{content}
|
||||
// When tools are present, content stops before the next ">>>" (tool call)
|
||||
// When no tools, content goes until end
|
||||
auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>"));
|
||||
auto content_until_end = p.literal(">>>all\n") + p.content(p.rest());
|
||||
auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
|
||||
auto content_until_end = p.literal("all\n") + p.content(p.rest());
|
||||
|
||||
// If no tools or tool_choice is NONE, just parse content
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// When no tools, just match the prefix and capture everything after
|
||||
return content_until_end + p.end();
|
||||
return wrap_for_generation_prompt(p, content_until_end + p.end(), inputs);
|
||||
}
|
||||
|
||||
// Build tool call parsers for each available function
|
||||
|
|
@ -1105,7 +1111,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
|
||||
// Tool format: >>>function_name\n{json_args}
|
||||
auto tool_parser = p.tool(
|
||||
p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) +
|
||||
p.tool_open(p.tool_name(p.literal(name)) + p.literal("\n")) +
|
||||
p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))
|
||||
);
|
||||
|
||||
|
|
@ -1116,17 +1122,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice));
|
||||
auto content_and_tools = content_until_tool + tools_only;
|
||||
|
||||
auto ret = p.eps();
|
||||
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
|
||||
if (inputs.parallel_tool_calls) {
|
||||
return p.choice({ content_and_tools, tools_only }) + p.end();
|
||||
ret = p.choice({ content_and_tools, tools_only }) + p.end();
|
||||
} else {
|
||||
ret = p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
|
||||
}
|
||||
return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
|
||||
} else if (inputs.parallel_tool_calls) {
|
||||
ret = p.choice({ content_and_tools, content_only, tools_only }) + p.end();
|
||||
} else {
|
||||
auto content_and_tool = content_until_tool + tool_choice;
|
||||
ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
|
||||
}
|
||||
if (inputs.parallel_tool_calls) {
|
||||
return p.choice({ content_and_tools, content_only, tools_only }) + p.end();
|
||||
}
|
||||
auto content_and_tool = content_until_tool + tool_choice;
|
||||
return p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
|
||||
return wrap_for_generation_prompt(p, ret, inputs);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1156,14 +1165,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
// Kimi K2 Thinking - uses unique tool call ID format: functions.<name>:<index>
|
||||
// The ID contains both the function name and an incrementing counter
|
||||
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_calls_section_end|>",
|
||||
|
|
@ -1178,6 +1185,18 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
|
||||
const std::string SECTION_END = "<|tool_calls_section_end|>";
|
||||
const std::string CALL_BEGIN = "<|tool_call_begin|>";
|
||||
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
|
||||
const std::string CALL_END = "<|tool_call_end|>";
|
||||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// Kimi K2 Thinking format:
|
||||
// - Reasoning: <think>{reasoning}</think>
|
||||
|
|
@ -1189,16 +1208,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
// <|tool_calls_section_end|>
|
||||
// The ID format is: functions.<function_name>:<counter> where counter is 0, 1, 2, ...
|
||||
|
||||
// Tool call markers
|
||||
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
|
||||
const std::string SECTION_END = "<|tool_calls_section_end|>";
|
||||
const std::string CALL_BEGIN = "<|tool_call_begin|>";
|
||||
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
|
||||
const std::string CALL_END = "<|tool_call_end|>";
|
||||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
|
||||
// Tool call markers
|
||||
auto end = p.end();
|
||||
|
||||
// Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny.
|
||||
|
|
@ -1210,7 +1220,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
|
||||
// Content only parser (no tools)
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return reasoning + p.content(p.rest()) + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end,
|
||||
inputs, THINK_START);
|
||||
}
|
||||
|
||||
// Build tool call parsers for each available function
|
||||
|
|
@ -1246,7 +1257,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
|
||||
auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN }));
|
||||
|
||||
return reasoning + content_before_tools + tool_calls + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + content_before_tools + tool_calls + end,
|
||||
inputs, THINK_START);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1276,7 +1288,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
|
||||
// Tool calls can appear multiple times (parallel tool calls)
|
||||
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
|
|
@ -1295,13 +1307,15 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
|
||||
const std::string TOOL_CALL_START = "<|tool_call_start|>";
|
||||
const std::string TOOL_CALL_END = "<|tool_call_end|>";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
|
|
@ -1310,7 +1324,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return reasoning + p.content(p.rest()) + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, inputs,
|
||||
THINK_START);
|
||||
}
|
||||
|
||||
auto tool_calls = p.rule("tool-calls",
|
||||
|
|
@ -1322,7 +1337,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
|
||||
auto content = p.content(p.until(TOOL_CALL_START));
|
||||
|
||||
return reasoning + content + tool_calls + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + content + tool_calls + end, inputs,
|
||||
THINK_START);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1348,7 +1364,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
|
||||
static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
|
||||
common_chat_params data;
|
||||
|
||||
|
|
@ -1362,9 +1378,10 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
|||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
auto tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
|
||||
const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto ret = p.eps();
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// Build a choice of all available tools
|
||||
auto tool_choice = p.choice();
|
||||
|
|
@ -1387,13 +1404,14 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
|||
auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice);
|
||||
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
|
||||
|
||||
return p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
|
||||
ret = p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
|
||||
} else {
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
ret = p.content(p.rest());
|
||||
}
|
||||
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
return p.content(p.rest());
|
||||
|
||||
return wrap_for_generation_prompt(p, ret, inputs);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1488,87 +1506,10 @@ static json common_chat_extra_context() {
|
|||
return ctx;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::templates_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
||||
? *tmpls->template_tool_use
|
||||
: *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
}
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
workaround::system_message_not_supported(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_tool_calls) {
|
||||
// some templates will require the content field in tool call messages
|
||||
// to still be non-null, this puts an empty string everywhere where the
|
||||
// content field is null
|
||||
workaround::requires_non_null_content(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_object_arguments) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
params.extra_context[el.first] = json::parse(el.second);
|
||||
}
|
||||
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
|
||||
// if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
||||
// LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
||||
// params.parallel_tool_calls = false;
|
||||
// } else {
|
||||
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
||||
//}
|
||||
|
||||
if (params.tools.is_array()) {
|
||||
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||
LOG_WRN(
|
||||
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
|
||||
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs.force_pure_content) {
|
||||
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
auto parser = build_chat_peg_parser([](common_chat_peg_builder &p) {
|
||||
return p.content(p.rest());
|
||||
});
|
||||
data.parser = parser.save();
|
||||
return data;
|
||||
}
|
||||
|
||||
static std::optional<common_chat_params> try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
const autoparser::generation_params & params) {
|
||||
// Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser
|
||||
// Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them
|
||||
if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
|
||||
|
|
@ -1609,14 +1550,105 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
// GigaChatV3 format detection
|
||||
if (src.find("<|role_sep|>") != std::string::npos &&
|
||||
src.find("<|message_sep|>") != std::string::npos &&
|
||||
src.find("<|function_call|>") == std::string::npos
|
||||
) {
|
||||
src.find("<|function_call|>") == std::string::npos) {
|
||||
LOG_DBG("Using specialized template: GigaChatV3\n");
|
||||
return common_chat_params_init_gigachat_v3(tmpl, params);
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::generation_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl =
|
||||
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
}
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
workaround::system_message_not_supported(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_tool_calls) {
|
||||
// some templates will require the content field in tool call messages
|
||||
// to still be non-null, this puts an empty string everywhere where the
|
||||
// content field is null
|
||||
workaround::requires_non_null_content(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_object_arguments) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right;
|
||||
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
params.extra_context[el.first] = json::parse(el.second);
|
||||
}
|
||||
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
|
||||
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
||||
|
||||
if (params.tools.is_array()) {
|
||||
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||
LOG_WRN(
|
||||
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
|
||||
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs.force_pure_content) {
|
||||
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.generation_prompt = params.generation_prompt;
|
||||
auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) {
|
||||
return wrap_for_generation_prompt(p, p.content(p.rest()), params);
|
||||
});
|
||||
data.parser = parser.save();
|
||||
return data;
|
||||
}
|
||||
|
||||
if (auto result = try_specialized_template(tmpl, src, params)) {
|
||||
result->generation_prompt = params.generation_prompt;
|
||||
return *result;
|
||||
}
|
||||
|
||||
try {
|
||||
LOG_DBG("Using differential autoparser\n");
|
||||
LOG_DBG("%s: using differential autoparser\n", __func__);
|
||||
struct autoparser::autoparser autoparser;
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
|
@ -1624,13 +1656,11 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
if (auto_params.supports_thinking) {
|
||||
auto_params.thinking_start_tag = autoparser.reasoning.start;
|
||||
auto_params.thinking_end_tag = autoparser.reasoning.end;
|
||||
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
|
||||
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
|
||||
// but forces <think> open when thinking is enabled)
|
||||
auto_params.thinking_forced_open =
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
auto_params.generation_prompt = params.generation_prompt;
|
||||
common_peg_arena arena;
|
||||
arena.load(auto_params.parser);
|
||||
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
|
||||
return auto_params;
|
||||
} catch (const std::exception & e) {
|
||||
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
|
||||
|
|
@ -1728,14 +1758,18 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
|||
LOG_DBG("No parser definition detected, assuming pure content parser.");
|
||||
}
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
|
||||
const std::string effective_input = params.generation_prompt.empty()
|
||||
? input
|
||||
: params.generation_prompt + input;
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
|
||||
|
||||
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
|
||||
if (params.debug) {
|
||||
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
|
||||
}
|
||||
|
||||
common_peg_parse_context ctx(input, flags);
|
||||
common_peg_parse_context ctx(effective_input, flags);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
if (result.fail()) {
|
||||
|
|
@ -1755,7 +1789,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
|||
return msg;
|
||||
}
|
||||
throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " +
|
||||
input.substr(result.end));
|
||||
effective_input.substr(result.end));
|
||||
}
|
||||
|
||||
common_chat_msg msg;
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ using json = nlohmann::ordered_json;
|
|||
struct common_chat_templates;
|
||||
|
||||
namespace autoparser {
|
||||
struct templates_params;
|
||||
struct generation_params;
|
||||
} // namespace autoparser
|
||||
|
||||
struct common_chat_tool_call {
|
||||
|
|
@ -212,7 +212,7 @@ struct common_chat_params {
|
|||
std::string prompt;
|
||||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
std::string generation_prompt;
|
||||
bool supports_thinking = false;
|
||||
std::string thinking_start_tag; // e.g., "<think>"
|
||||
std::string thinking_end_tag; // e.g., "</think>"
|
||||
|
|
@ -229,14 +229,14 @@ struct common_chat_parser_params {
|
|||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
std::string generation_prompt;
|
||||
bool parse_tool_calls = true;
|
||||
bool debug = false; // Enable debug output for PEG parser
|
||||
common_peg_arena parser = {};
|
||||
common_chat_parser_params() = default;
|
||||
common_chat_parser_params(const common_chat_params & chat_params) {
|
||||
format = chat_params.format;
|
||||
thinking_forced_open = chat_params.thinking_forced_open;
|
||||
format = chat_params.format;
|
||||
generation_prompt = chat_params.generation_prompt;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -302,7 +302,7 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
|
|||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt);
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "ggml.h"
|
||||
#include "llama-cpp.h"
|
||||
#include "build-info.h"
|
||||
|
||||
|
|
@ -10,6 +11,7 @@
|
|||
#include <sstream>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
|
|
@ -175,6 +177,43 @@ enum common_speculative_type {
|
|||
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
||||
};
|
||||
|
||||
// Grammar type enumeration
|
||||
enum common_grammar_type {
|
||||
COMMON_GRAMMAR_TYPE_NONE, // no grammar set
|
||||
COMMON_GRAMMAR_TYPE_USER, // user-provided GBNF (--grammar / "grammar" API field)
|
||||
COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, // auto-generated from JSON schema (--json-schema / "json_schema" API field)
|
||||
COMMON_GRAMMAR_TYPE_TOOL_CALLS, // auto-generated by chat template parser for function calling
|
||||
};
|
||||
|
||||
// Grammar variant struct with type and grammar string
|
||||
struct common_grammar {
|
||||
common_grammar_type type = COMMON_GRAMMAR_TYPE_NONE;
|
||||
std::string grammar;
|
||||
|
||||
// Default constructor - no grammar
|
||||
common_grammar() = default;
|
||||
|
||||
// Constructor with type and grammar string
|
||||
common_grammar(common_grammar_type t, std::string g) : type(t), grammar(std::move(g)) {
|
||||
GGML_ASSERT(type != COMMON_GRAMMAR_TYPE_NONE || !grammar.empty());
|
||||
}
|
||||
|
||||
// Check if a grammar is set
|
||||
bool empty() const { return type == COMMON_GRAMMAR_TYPE_NONE || grammar.empty(); }
|
||||
};
|
||||
|
||||
// Returns the raw grammar string, or empty string if no grammar is set.
|
||||
inline const std::string & common_grammar_value(const common_grammar & g) {
|
||||
return g.grammar;
|
||||
}
|
||||
|
||||
// Returns true when the generation_prompt should be prefilled into the grammar sampler.
|
||||
// Only output-format and tool-call grammars need prefill; user-supplied grammars must not be prefilled.
|
||||
inline bool common_grammar_needs_prefill(const common_grammar & g) {
|
||||
return g.type == COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT
|
||||
|| g.type == COMMON_GRAMMAR_TYPE_TOOL_CALLS;
|
||||
}
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
|
|
@ -225,7 +264,7 @@ struct common_params_sampling {
|
|||
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
||||
};
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
common_grammar grammar; // optional grammar constraint (user / output-format / tool-calls)
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
|
|
@ -233,10 +272,15 @@ struct common_params_sampling {
|
|||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
// The assistant generation prompt already prefilled into the prompt.
|
||||
// Fed to the grammar sampler (to advance past pre-existing tokens) and used
|
||||
// to determine the reasoning budget sampler's initial state.
|
||||
// Only applied when the grammar is of output-format or tool-calls type.
|
||||
std::string generation_prompt;
|
||||
|
||||
// reasoning budget sampler parameters
|
||||
// these are populated by the server/CLI based on chat template params
|
||||
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
|
||||
bool reasoning_budget_activate_immediately = false;
|
||||
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
|
||||
|
|
|
|||
|
|
@ -451,7 +451,7 @@ struct value_array_t : public value_t {
|
|||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), other.val_arr.end(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_array = std::shared_ptr<value_array_t>;
|
||||
|
|
@ -587,7 +587,7 @@ struct value_object_t : public value_t {
|
|||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), other.val_obj.end(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_object = std::shared_ptr<value_object_t>;
|
||||
|
|
|
|||
|
|
@ -163,9 +163,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
|||
ctx->force_pos = 0;
|
||||
}
|
||||
|
||||
// forward declaration for use in clone
|
||||
static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget, common_reasoning_budget_state initial_state);
|
||||
|
||||
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
|
||||
return common_reasoning_budget_init(
|
||||
return common_reasoning_budget_init_state(
|
||||
ctx->vocab,
|
||||
ctx->start_matcher.tokens,
|
||||
ctx->end_matcher.tokens,
|
||||
|
|
@ -191,13 +197,13 @@ static struct llama_sampler_i common_reasoning_budget_i = {
|
|||
/* .backend_set_input = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
// promote COUNTING with budget <= 0 to FORCING
|
||||
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
|
||||
initial_state = REASONING_BUDGET_FORCING;
|
||||
|
|
@ -217,3 +223,41 @@ struct llama_sampler * common_reasoning_budget_init(
|
|||
}
|
||||
);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens) {
|
||||
// Determine initial state from prefill: COUNTING if the prefill begins with
|
||||
// the start sequence but does not also contain the end sequence after it.
|
||||
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
|
||||
if (!prefill_tokens.empty() && !start_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() &&
|
||||
std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
|
||||
initial_state = REASONING_BUDGET_COUNTING;
|
||||
// If the end sequence also follows the start in the prefill, reasoning
|
||||
// was opened and immediately closed — stay IDLE.
|
||||
if (!end_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
|
||||
auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
|
||||
if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
|
||||
std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
|
||||
initial_state = REASONING_BUDGET_IDLE;
|
||||
}
|
||||
}
|
||||
}
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,14 +24,26 @@ enum common_reasoning_budget_state {
|
|||
// DONE: passthrough forever
|
||||
//
|
||||
// Parameters:
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
|
||||
// note: COUNTING with budget <= 0 is promoted to FORCING
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// prefill_tokens - tokens already present in the prompt (generation prompt);
|
||||
// used to determine the initial state: COUNTING if they begin
|
||||
// with start_tokens (but don't also end with end_tokens),
|
||||
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
|
||||
//
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens = {});
|
||||
|
||||
// Variant that takes an explicit initial state (used by tests and clone).
|
||||
// COUNTING with budget <= 0 is promoted to FORCING.
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
#include "sampling.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "log.h"
|
||||
#include "reasoning-budget.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
// TODO: deduplicate with llama-impl.h
|
||||
|
|
@ -189,9 +192,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
|
||||
std::vector<llama_sampler *> samplers;
|
||||
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
const std::string & grammar_str = common_grammar_value(params.grammar);
|
||||
if (grammar_str.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
|
||||
#else
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
|
|
@ -240,17 +244,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
trigger_patterns_c.push_back(regex.c_str());
|
||||
}
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
if (!grammar_str.empty()) {
|
||||
if (params.grammar_lazy) {
|
||||
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size());
|
||||
} else {
|
||||
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
grmr = llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Feed generation prompt tokens to the grammar sampler so it advances past
|
||||
// tokens the template already placed in the prompt.
|
||||
// Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
|
||||
std::vector<llama_token> prefill_tokens;
|
||||
if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
|
||||
GGML_ASSERT(vocab != nullptr);
|
||||
prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
|
||||
if (!prefill_tokens.empty()) {
|
||||
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
|
||||
if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
|
||||
// Some tokenizers will add a space before the first special token, need to remove
|
||||
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
|
||||
}
|
||||
}
|
||||
|
||||
if (grmr) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
|
||||
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler — added first so it can force tokens before other samplers
|
||||
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
|
||||
samplers.push_back(common_reasoning_budget_init(
|
||||
|
|
@ -259,7 +292,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens,
|
||||
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
|
||||
prefill_tokens));
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
|
|
|
|||
|
|
@ -3194,6 +3194,7 @@ class tinyBLAS_PPC {
|
|||
|
||||
private:
|
||||
|
||||
__attribute__((always_inline))
|
||||
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
|
|
@ -3204,6 +3205,7 @@ class tinyBLAS_PPC {
|
|||
}
|
||||
}
|
||||
|
||||
__attribute__((always_inline))
|
||||
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
|
|
|
|||
|
|
@ -4620,12 +4620,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
|
||||
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
|
||||
};
|
||||
const bool use_subgroup_reduce = device->subgroup_arithmetic;
|
||||
for (uint32_t si = 0; si < 3; si++) {
|
||||
const uint32_t S_V = gdn_sizes[si];
|
||||
GGML_ASSERT(is_pow2(S_V));
|
||||
|
||||
uint32_t lanes_per_column;
|
||||
if (S_V >= 128u && device->subgroup_clustered) {
|
||||
lanes_per_column = 8u;
|
||||
} else {
|
||||
// Use largest power-of-two that divides both S_V and subgroup_size so that
|
||||
// (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
|
||||
// This means we don't need extra bounds checking logic in the shader.
|
||||
lanes_per_column = std::min(S_V, device->subgroup_size);
|
||||
}
|
||||
|
||||
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
|
||||
size_t gdn_len;
|
||||
const void * gdn_data;
|
||||
if (use_subgroup_reduce && need_clustered_shader) {
|
||||
gdn_len = gated_delta_net_f32_len;
|
||||
gdn_data = (const void *)gated_delta_net_f32_data;
|
||||
} else if (use_subgroup_reduce) {
|
||||
gdn_len = gated_delta_net_f32_nocluster_len;
|
||||
gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
|
||||
} else {
|
||||
gdn_len = gated_delta_net_f32_shmem_len;
|
||||
gdn_data = (const void *)gated_delta_net_f32_shmem_data;
|
||||
}
|
||||
|
||||
const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
|
||||
const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg};
|
||||
|
||||
for (uint32_t kda = 0; kda < 2; kda++) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
|
||||
gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
|
||||
"main", 7, sizeof(vk_op_gated_delta_net_push_constants),
|
||||
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
|
||||
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
|
||||
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -10476,7 +10506,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
|||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
|
||||
pc, { H, n_seqs, 1u });
|
||||
pc, { H, n_seqs, S_v });
|
||||
}
|
||||
|
||||
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
||||
|
|
|
|||
|
|
@ -1,11 +1,25 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#if USE_SUBGROUP_CLUSTERED
|
||||
#extension GL_KHR_shader_subgroup_clustered : enable
|
||||
#endif
|
||||
#if USE_SUBGROUP_ADD
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#endif
|
||||
|
||||
// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
|
||||
// so no bounds checking is needed.
|
||||
layout(constant_id = 0) const uint S_V = 128;
|
||||
layout(constant_id = 1) const uint KDA = 0;
|
||||
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
|
||||
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
|
||||
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;
|
||||
|
||||
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
uint H;
|
||||
|
|
@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
|
|||
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
|
||||
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
|
||||
|
||||
shared FLOAT_TYPE s_k[S_V];
|
||||
shared FLOAT_TYPE s_q[S_V];
|
||||
shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
|
||||
#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
|
||||
shared FLOAT_TYPE temp[SUBGROUP_SIZE];
|
||||
|
||||
// This does a reduction across groups of LANES_PER_COLUMN
|
||||
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
|
||||
const uint lane = gl_SubgroupInvocationID;
|
||||
temp[lane] = partial;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
|
||||
FLOAT_TYPE other = temp[lane ^ s];
|
||||
barrier();
|
||||
temp[lane] += other;
|
||||
barrier();
|
||||
}
|
||||
const FLOAT_TYPE result = temp[lane];
|
||||
barrier();
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant
|
||||
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
|
||||
switch (LANES_PER_COLUMN) {
|
||||
case 1u:
|
||||
return partial;
|
||||
#if USE_SUBGROUP_CLUSTERED
|
||||
// Workaround for GLSL requiring a literal constant for the cluster size.
|
||||
// The branches should all fold away.
|
||||
case 2u:
|
||||
return subgroupClusteredAdd(partial, 2u);
|
||||
case 4u:
|
||||
return subgroupClusteredAdd(partial, 4u);
|
||||
case 8u:
|
||||
return subgroupClusteredAdd(partial, 8u);
|
||||
case 16u:
|
||||
return subgroupClusteredAdd(partial, 16u);
|
||||
case 32u:
|
||||
return subgroupClusteredAdd(partial, 32u);
|
||||
case 64u:
|
||||
return subgroupClusteredAdd(partial, 64u);
|
||||
#endif
|
||||
default:
|
||||
#if USE_SUBGROUP_ADD
|
||||
return subgroupAdd(partial);
|
||||
#else
|
||||
return reduce_add_shmem(partial);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint head_id = gl_WorkGroupID.x;
|
||||
const uint seq_id = gl_WorkGroupID.y;
|
||||
const uint col = gl_LocalInvocationID.x;
|
||||
const uint seq_id = gl_WorkGroupID.y;
|
||||
const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
|
||||
const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);
|
||||
|
||||
const uint iq1 = head_id % neq1;
|
||||
const uint iq3 = seq_id / rq3;
|
||||
|
|
@ -42,9 +103,9 @@ void main() {
|
|||
const uint state_size = S_V * S_V;
|
||||
const uint state_base = (seq_id * H + head_id) * state_size;
|
||||
|
||||
FLOAT_TYPE state[S_V];
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
|
||||
FLOAT_TYPE s_shard[ROWS_PER_LANE];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
|
||||
}
|
||||
|
||||
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
|
||||
|
|
@ -53,76 +114,56 @@ void main() {
|
|||
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
|
||||
const uint k_off = q_off;
|
||||
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
|
||||
|
||||
s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
|
||||
s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
|
||||
|
||||
const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
|
||||
|
||||
if (KDA != 0) {
|
||||
const uint g_base = gb_off * S_V;
|
||||
s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
|
||||
const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
|
||||
|
||||
FLOAT_TYPE k_reg[ROWS_PER_LANE];
|
||||
FLOAT_TYPE q_reg[ROWS_PER_LANE];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const uint i = r * LANES_PER_COLUMN + lane;
|
||||
k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
|
||||
q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
|
||||
}
|
||||
|
||||
FLOAT_TYPE g_exp[ROWS_PER_LANE];
|
||||
if (KDA == 0) {
|
||||
const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
|
||||
|
||||
FLOAT_TYPE kv_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
kv_col += dot(
|
||||
vec4(state[i], state[i+1], state[i+2], state[i+3]),
|
||||
vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
|
||||
);
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
g_exp[r] = g_val;
|
||||
}
|
||||
|
||||
FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
|
||||
|
||||
FLOAT_TYPE attn_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
|
||||
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
|
||||
sv = g_val * sv + kv * delta_col;
|
||||
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
|
||||
|
||||
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
|
||||
}
|
||||
|
||||
data_dst[attn_off + col] = attn_col * scale;
|
||||
} else {
|
||||
FLOAT_TYPE kv_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
|
||||
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
|
||||
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
|
||||
kv_col += dot(gv * sv, kv);
|
||||
const uint g_base = gb_off * S_V;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const uint i = r * LANES_PER_COLUMN + lane;
|
||||
g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
|
||||
}
|
||||
}
|
||||
|
||||
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
|
||||
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
|
||||
|
||||
FLOAT_TYPE attn_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
|
||||
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
|
||||
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
|
||||
sv = gv * sv + kv * delta_col;
|
||||
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
|
||||
FLOAT_TYPE kv_shard = 0.0;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
|
||||
}
|
||||
FLOAT_TYPE kv_col = reduce_partial(kv_shard);
|
||||
|
||||
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
|
||||
}
|
||||
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
|
||||
|
||||
FLOAT_TYPE attn_partial = 0.0;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
|
||||
attn_partial += s_shard[r] * q_reg[r];
|
||||
}
|
||||
FLOAT_TYPE attn_col = reduce_partial(attn_partial);
|
||||
|
||||
if (lane == 0) {
|
||||
data_dst[attn_off + col] = attn_col * scale;
|
||||
}
|
||||
|
||||
attn_off += S_V * H;
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
data_dst[s_off + state_base + col * S_V + i] = state[i];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1004,7 +1004,9 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
|
||||
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}}));
|
||||
string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
|
||||
string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
|
|
|||
|
|
@ -1948,6 +1948,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
|
||||
return 0;
|
||||
}
|
||||
ggml_backend_buffer_clear(buf_output.get(), 0);
|
||||
}
|
||||
|
||||
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
|
||||
|
|
|
|||
|
|
@ -1787,6 +1787,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
|
||||
// NextN/MTP parameters (GLM-OCR)
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
|
@ -1820,6 +1821,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
|
||||
// NextN/MTP parameters
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
|
@ -1866,6 +1868,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
|
||||
// NextN/MTP parameters
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
|
@ -2040,6 +2043,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
|
||||
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_30B_A3B; break;
|
||||
|
|
@ -2168,7 +2172,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
|
||||
switch (hparams.n_embd) {
|
||||
case 768: type = LLM_TYPE_350M; break;
|
||||
case 1536: type = (hparams.n_embd == 2048 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break;
|
||||
case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break;
|
||||
case 2048: case 2560: type = LLM_TYPE_3B; break;
|
||||
case 4096: type = LLM_TYPE_32B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
|
|
@ -2222,6 +2226,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
|
|
|||
|
|
@ -282,7 +282,7 @@ static void render_scenario(const common_chat_template & tmpl,
|
|||
LOG_ERR("Messages:\n%s\n", final_messages.dump(2).c_str());
|
||||
|
||||
try {
|
||||
autoparser::templates_params inputs;
|
||||
autoparser::generation_params inputs;
|
||||
inputs.messages = final_messages;
|
||||
inputs.add_generation_prompt = add_generation_prompt;
|
||||
inputs.extra_context["enable_thinking"] = enable_thinking;
|
||||
|
|
@ -395,7 +395,7 @@ int main(int argc, char ** argv) {
|
|||
analysis.analyze_template(chat_template);
|
||||
|
||||
// Generate Parser
|
||||
autoparser::templates_params params;
|
||||
autoparser::generation_params params;
|
||||
params.messages = json::array({ build_user_message() });
|
||||
params.reasoning_format =
|
||||
opts.enable_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE;
|
||||
|
|
|
|||
|
|
@ -400,12 +400,12 @@ static void analyze_template(const std::string & template_path) {
|
|||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_tools;
|
||||
autoparser::generation_params params_no_tools;
|
||||
params_no_tools.messages = json::array({ user_msg });
|
||||
params_no_tools.add_generation_prompt = false;
|
||||
params_no_tools.tools = json::array();
|
||||
|
||||
autoparser::templates_params params_with_tools = params_no_tools;
|
||||
autoparser::generation_params params_with_tools = params_no_tools;
|
||||
params_with_tools.tools = tools;
|
||||
|
||||
std::string output_no_tools = common_chat_template_direct_apply(chat_template, params_no_tools);
|
||||
|
|
@ -419,12 +419,12 @@ static void analyze_template(const std::string & template_path) {
|
|||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_prompt;
|
||||
autoparser::generation_params params_no_prompt;
|
||||
params_no_prompt.messages = json::array({ user_msg });
|
||||
params_no_prompt.add_generation_prompt = false;
|
||||
params_no_prompt.tools = json::array();
|
||||
|
||||
autoparser::templates_params params_with_prompt = params_no_prompt;
|
||||
autoparser::generation_params params_with_prompt = params_no_prompt;
|
||||
params_with_prompt.add_generation_prompt = true;
|
||||
|
||||
std::string output_no_prompt = common_chat_template_direct_apply(chat_template, params_no_prompt);
|
||||
|
|
@ -438,12 +438,12 @@ static void analyze_template(const std::string & template_path) {
|
|||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_reasoning;
|
||||
autoparser::generation_params params_no_reasoning;
|
||||
params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning() });
|
||||
params_no_reasoning.add_generation_prompt = false;
|
||||
params_no_reasoning.enable_thinking = true;
|
||||
|
||||
autoparser::templates_params params_with_reasoning = params_no_reasoning;
|
||||
autoparser::generation_params params_with_reasoning = params_no_reasoning;
|
||||
params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning() });
|
||||
|
||||
std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning);
|
||||
|
|
@ -458,12 +458,12 @@ static void analyze_template(const std::string & template_path) {
|
|||
json user_msg = make_user_msg();
|
||||
json user_msg2 = make_user_msg2();
|
||||
|
||||
autoparser::templates_params params_no_reasoning;
|
||||
autoparser::generation_params params_no_reasoning;
|
||||
params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning(), user_msg2 });
|
||||
params_no_reasoning.add_generation_prompt = false;
|
||||
params_no_reasoning.enable_thinking = true;
|
||||
|
||||
autoparser::templates_params params_with_reasoning = params_no_reasoning;
|
||||
autoparser::generation_params params_with_reasoning = params_no_reasoning;
|
||||
params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning(), user_msg2 });
|
||||
|
||||
std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning);
|
||||
|
|
@ -477,12 +477,12 @@ static void analyze_template(const std::string & template_path) {
|
|||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_tool;
|
||||
autoparser::generation_params params_no_tool;
|
||||
params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool() });
|
||||
params_no_tool.add_generation_prompt = false;
|
||||
params_no_tool.tools = tools;
|
||||
|
||||
autoparser::templates_params params_with_tool = params_no_tool;
|
||||
autoparser::generation_params params_with_tool = params_no_tool;
|
||||
params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool() });
|
||||
|
||||
std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool);
|
||||
|
|
@ -497,12 +497,12 @@ static void analyze_template(const std::string & template_path) {
|
|||
json user_msg = make_user_msg();
|
||||
json user_msg2 = make_user_msg2_continue();
|
||||
|
||||
autoparser::templates_params params_no_tool;
|
||||
autoparser::generation_params params_no_tool;
|
||||
params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool(), user_msg2 });
|
||||
params_no_tool.add_generation_prompt = false;
|
||||
params_no_tool.tools = tools;
|
||||
|
||||
autoparser::templates_params params_with_tool = params_no_tool;
|
||||
autoparser::generation_params params_with_tool = params_no_tool;
|
||||
params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 });
|
||||
|
||||
std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool);
|
||||
|
|
@ -516,12 +516,12 @@ static void analyze_template(const std::string & template_path) {
|
|||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_one_tool;
|
||||
autoparser::generation_params params_one_tool;
|
||||
params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool() });
|
||||
params_one_tool.add_generation_prompt = false;
|
||||
params_one_tool.tools = tools;
|
||||
|
||||
autoparser::templates_params params_two_tools = params_one_tool;
|
||||
autoparser::generation_params params_two_tools = params_one_tool;
|
||||
params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools() });
|
||||
|
||||
std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool);
|
||||
|
|
@ -536,12 +536,12 @@ static void analyze_template(const std::string & template_path) {
|
|||
json user_msg = make_user_msg();
|
||||
json user_msg2 = make_user_msg2_continue();
|
||||
|
||||
autoparser::templates_params params_one_tool;
|
||||
autoparser::generation_params params_one_tool;
|
||||
params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 });
|
||||
params_one_tool.add_generation_prompt = false;
|
||||
params_one_tool.tools = tools;
|
||||
|
||||
autoparser::templates_params params_two_tools = params_one_tool;
|
||||
autoparser::generation_params params_two_tools = params_one_tool;
|
||||
params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools(), user_msg2 });
|
||||
|
||||
std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool);
|
||||
|
|
@ -555,13 +555,13 @@ static void analyze_template(const std::string & template_path) {
|
|||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_reasoning;
|
||||
autoparser::generation_params params_no_reasoning;
|
||||
params_no_reasoning.messages = json::array({ user_msg, make_assistant_one_tool() });
|
||||
params_no_reasoning.add_generation_prompt = false;
|
||||
params_no_reasoning.tools = tools;
|
||||
params_no_reasoning.enable_thinking = true;
|
||||
|
||||
autoparser::templates_params params_with_reasoning = params_no_reasoning;
|
||||
autoparser::generation_params params_with_reasoning = params_no_reasoning;
|
||||
params_with_reasoning.messages = json::array({ user_msg, make_assistant_one_tool_with_reasoning() });
|
||||
|
||||
std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning);
|
||||
|
|
|
|||
|
|
@ -4,6 +4,36 @@ This document provides an in-depth technical overview of `llama-server`, intende
|
|||
|
||||
If you are an end user consuming `llama-server` as a product, please refer to the main [README](./README.md) instead.
|
||||
|
||||
## Scope of features
|
||||
|
||||
In-scope types of feature:
|
||||
|
||||
- Backend:
|
||||
- Basic inference features: text completion, embeddings output
|
||||
- Chat-oriented features: chat completion, tool calling
|
||||
- Third-party API compatibility, e.g. OAI-compat, Anthropic-compat
|
||||
- Multimodal input/output
|
||||
- Memory management: save/load state, context checkpoints
|
||||
- Model management
|
||||
- Features that are required by the Web UI
|
||||
- Frontend:
|
||||
- Chat-oriented features, example: basic chat, image upload, edit messages
|
||||
- Agentic features, example: MCP
|
||||
- Model management
|
||||
|
||||
Note: For security reasons, features that require reading or writing external files must be **disabled by default**. This covers features like: MCP, model save/load
|
||||
|
||||
Out-of-scope features:
|
||||
|
||||
- Backend:
|
||||
- Features that require a loop of external API calls, e.g. server-side agentic loop. This is because external API calls in C++ are costly to maintain. Any complex third-party logic should be implemented outside of server code.
|
||||
- Features that expose the internal state of the model to the API, example: getting the intermediate activation from API. This is because llama.cpp doesn't support a stable API for doing this, and relying on `eval_callback` can make it complicated to maintain as this API is not intended to be used in multi-sequence setup.
|
||||
- Model-specific features. All API calls and features must remain model-agnostic.
|
||||
- Frontend:
|
||||
- Third-party plugins, it is costly to maintain a public plugin API for such features. Instead, users can make their own MCP server for their needs.
|
||||
- Customizable themes, it is also costly to maintain. While we do focus on the aesthetic, we try to achieve this by perfecting a small set of themes.
|
||||
- Browser-specific features, example: [Chrome's built-in AI API](https://developer.chrome.com/docs/ai/built-in-apis).
|
||||
|
||||
## Backend
|
||||
|
||||
### Overview
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -1081,20 +1081,21 @@ json oaicompat_chat_params_parse(
|
|||
}
|
||||
}
|
||||
|
||||
llama_params["chat_format"] = static_cast<int>(chat_params.format);
|
||||
llama_params["prompt"] = chat_params.prompt;
|
||||
llama_params["chat_format"] = static_cast<int>(chat_params.format);
|
||||
llama_params["prompt"] = chat_params.prompt;
|
||||
if (!chat_params.grammar.empty()) {
|
||||
llama_params["grammar"] = chat_params.grammar;
|
||||
llama_params["grammar"] = chat_params.grammar;
|
||||
llama_params["grammar_type"] = std::string("tool_calls");
|
||||
}
|
||||
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
|
||||
auto grammar_triggers = json::array();
|
||||
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
|
||||
auto grammar_triggers = json::array();
|
||||
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||
server_grammar_trigger ct(trigger);
|
||||
grammar_triggers.push_back(ct.to_json());
|
||||
}
|
||||
llama_params["grammar_triggers"] = grammar_triggers;
|
||||
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
|
||||
llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
|
||||
llama_params["grammar_triggers"] = grammar_triggers;
|
||||
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
|
||||
llama_params["generation_prompt"] = chat_params.generation_prompt;
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
|
|
@ -1114,7 +1115,6 @@ json oaicompat_chat_params_parse(
|
|||
llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
|
||||
llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
|
||||
llama_params["reasoning_budget_message"] = opt.reasoning_budget_message;
|
||||
llama_params["reasoning_budget_activate_immediately"] = chat_params.thinking_forced_open;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cinttypes>
|
||||
#include <exception>
|
||||
#include <memory>
|
||||
#include <filesystem>
|
||||
|
||||
|
|
@ -1152,11 +1153,11 @@ private:
|
|||
|
||||
// initialize samplers
|
||||
if (task.need_sampling()) {
|
||||
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
|
||||
|
||||
if (slot.smpl == nullptr) {
|
||||
// for now, the only error that may happen here is invalid grammar
|
||||
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
||||
try {
|
||||
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
|
||||
} catch (std::exception & e) {
|
||||
std::string err_msg = std::string("Failed to initialize samplers: ") + e.what();
|
||||
send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1431,9 +1432,10 @@ private:
|
|||
res->tokens = { tkn.tok };
|
||||
}
|
||||
|
||||
res->n_decoded = slot.n_decoded;
|
||||
res->n_prompt_tokens = slot.task->n_tokens();
|
||||
res->post_sampling_probs = slot.task->params.post_sampling_probs;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
res->n_prompt_tokens = slot.task->n_tokens();
|
||||
res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache;
|
||||
res->post_sampling_probs = slot.task->params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.task->params.verbose;
|
||||
res->res_type = slot.task->params.res_type;
|
||||
|
|
@ -1478,14 +1480,15 @@ private:
|
|||
res->prompt = slot.task->tokens.detokenize(ctx, true);
|
||||
res->response_fields = std::move(slot.task->params.response_fields);
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
res->n_prompt_tokens = slot.task->n_tokens();
|
||||
res->n_tokens_cached = slot.prompt.n_tokens();
|
||||
res->has_new_line = slot.has_new_line;
|
||||
res->stopping_word = slot.stopping_word;
|
||||
res->stop = slot.stop;
|
||||
res->post_sampling_probs = slot.task->params.post_sampling_probs;
|
||||
res->truncated = slot.truncated;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
res->n_prompt_tokens = slot.task->n_tokens();
|
||||
res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache;
|
||||
res->n_tokens_cached = slot.prompt.n_tokens();
|
||||
res->has_new_line = slot.has_new_line;
|
||||
res->stopping_word = slot.stopping_word;
|
||||
res->stop = slot.stop;
|
||||
res->post_sampling_probs = slot.task->params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.task->params.verbose;
|
||||
res->stream = slot.task->params.stream;
|
||||
|
|
@ -2304,8 +2307,8 @@ private:
|
|||
|
||||
llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
|
||||
|
||||
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
|
||||
const auto n_swa = std::max(1, llama_model_n_swa(model));
|
||||
// note: when n_swa == 0, the model does not use SWA
|
||||
const auto n_swa = std::max(0, llama_model_n_swa(model));
|
||||
|
||||
// the largest pos_min required for a checkpoint to be useful
|
||||
const auto pos_min_thold = std::max(0, pos_next - n_swa);
|
||||
|
|
@ -2360,7 +2363,7 @@ private:
|
|||
SLT_WRN(slot, "%s\n", st1.str().c_str());
|
||||
}
|
||||
|
||||
if (pos_min > pos_min_thold) {
|
||||
if (pos_min >= pos_min_thold) {
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
|
||||
// search for a context checkpoint
|
||||
|
|
@ -2456,8 +2459,39 @@ private:
|
|||
slot.n_prompt_tokens_cache = 0;
|
||||
}
|
||||
|
||||
// If using an alora, there may be uncached tokens that come
|
||||
// before the invocation sequence. When this happens, the
|
||||
// tokens before the invocation sequence need to be
|
||||
// processed without the adapter in a separate batch, then
|
||||
// the adapter needs to be enabled for the remaining tokens.
|
||||
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
|
||||
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
|
||||
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
|
||||
GGML_ASSERT(enabled_loras.size() == 1);
|
||||
alora_scale = slot.lora[enabled_loras[0]].scale;
|
||||
slot.lora[enabled_loras[0]].scale = 0.0f;
|
||||
alora_disabled_id = enabled_loras[0];
|
||||
}
|
||||
|
||||
bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
|
||||
|
||||
// make checkpoints only for completion tasks
|
||||
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
do_checkpoint = do_checkpoint && (
|
||||
llama_model_is_recurrent(model) ||
|
||||
llama_model_is_hybrid(model) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
||||
);
|
||||
|
||||
bool has_mtmd = false;
|
||||
|
||||
// check if we should process the image
|
||||
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
|
||||
// process the image
|
||||
|
|
@ -2478,38 +2512,9 @@ private:
|
|||
slot.prompt.tokens.push_back(chunk.get()); // copy
|
||||
}
|
||||
|
||||
do_checkpoint = false; // do not checkpoint right after an image chunk
|
||||
has_mtmd = true;
|
||||
}
|
||||
|
||||
// If using an alora, there may be uncached tokens that come
|
||||
// before the invocation sequence. When this happens, the
|
||||
// tokens before the invocation sequence need to be
|
||||
// processed without the adapter in a separate batch, then
|
||||
// the adapter needs to be enabled for the remaining tokens.
|
||||
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
|
||||
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
|
||||
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
|
||||
GGML_ASSERT(enabled_loras.size() == 1);
|
||||
alora_scale = slot.lora[enabled_loras[0]].scale;
|
||||
slot.lora[enabled_loras[0]].scale = 0.0f;
|
||||
alora_disabled_id = enabled_loras[0];
|
||||
}
|
||||
|
||||
// make checkpoints only for completion tasks
|
||||
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
do_checkpoint = do_checkpoint && (
|
||||
llama_model_is_recurrent(model) ||
|
||||
llama_model_is_hybrid(model) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
||||
);
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
|
||||
// get next token to process
|
||||
|
|
@ -2541,13 +2546,13 @@ private:
|
|||
// - 4 + n_ubatch
|
||||
// - 4
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/20288
|
||||
{
|
||||
if (do_checkpoint) {
|
||||
static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
|
||||
|
||||
bool should_break = false;
|
||||
for (int offset : checkpoint_offsets) {
|
||||
const int n_last = std::min(n_batch, offset);
|
||||
if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||
if (slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||
should_break = true;
|
||||
break;
|
||||
}
|
||||
|
|
@ -2604,10 +2609,13 @@ private:
|
|||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
|
||||
|
||||
// do not checkpoint after mtmd chunks
|
||||
do_checkpoint = do_checkpoint && !has_mtmd;
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
// yet processed and therefore it is not part of the checkpoint.
|
||||
|
|
|
|||
|
|
@ -539,6 +539,22 @@ void server_models::load(const std::string & name) {
|
|||
return;
|
||||
}
|
||||
|
||||
// Re-check capacity under the lock to prevent concurrent loads from
|
||||
// exceeding models_max. Without this, the window between unload_lru()
|
||||
// releasing its lock and this lock_guard acquiring allows multiple
|
||||
// threads to each observe capacity and all proceed to load.
|
||||
if (base_params.models_max > 0) {
|
||||
size_t count_active = 0;
|
||||
for (const auto & m : mapping) {
|
||||
if (m.second.meta.is_active()) {
|
||||
count_active++;
|
||||
}
|
||||
}
|
||||
if (count_active >= (size_t)base_params.models_max) {
|
||||
throw std::runtime_error("model limit reached, try again later");
|
||||
}
|
||||
}
|
||||
|
||||
// prepare new instance info
|
||||
instance_t inst;
|
||||
inst.meta = meta;
|
||||
|
|
@ -606,13 +622,20 @@ void server_models::load(const std::string & name) {
|
|||
});
|
||||
|
||||
std::thread stopping_thread([&]() {
|
||||
// thread to monitor stopping signal
|
||||
// thread to monitor stopping signal OR child crash
|
||||
auto is_stopping = [this, &name]() {
|
||||
return this->stopping_models.find(name) != this->stopping_models.end();
|
||||
};
|
||||
auto should_wake = [&]() {
|
||||
return is_stopping() || !subprocess_alive(child_proc.get());
|
||||
};
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(this->mutex);
|
||||
this->cv_stop.wait(lk, is_stopping);
|
||||
this->cv_stop.wait(lk, should_wake);
|
||||
}
|
||||
// child may have already exited (e.g. crashed) — skip shutdown sequence
|
||||
if (!subprocess_alive(child_proc.get())) {
|
||||
return;
|
||||
}
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
// send interrupt to child process
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ json task_params::to_json(bool only_metrics) const {
|
|||
{"chat_format", common_chat_format_name(chat_parser_params.format)},
|
||||
{"reasoning_format", common_reasoning_format_name(chat_parser_params.reasoning_format)},
|
||||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"thinking_forced_open", chat_parser_params.thinking_forced_open},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
|
|
@ -128,14 +128,14 @@ json task_params::to_json(bool only_metrics) const {
|
|||
{"logit_bias", format_logit_bias(sampling.logit_bias)},
|
||||
{"n_probs", sampling.n_probs},
|
||||
{"min_keep", sampling.min_keep},
|
||||
{"grammar", sampling.grammar},
|
||||
{"grammar", common_grammar_value(sampling.grammar)},
|
||||
{"grammar_lazy", sampling.grammar_lazy},
|
||||
{"grammar_triggers", grammar_triggers},
|
||||
{"preserved_tokens", sampling.preserved_tokens},
|
||||
{"chat_format", common_chat_format_name(chat_parser_params.format)},
|
||||
{"reasoning_format", common_reasoning_format_name(chat_parser_params.reasoning_format)},
|
||||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"thinking_forced_open", chat_parser_params.thinking_forced_open},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
|
|
@ -376,14 +376,25 @@ task_params server_task::params_from_json_cmpl(
|
|||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
|
||||
std::string grammar_str = json_schema_to_grammar(schema);
|
||||
SRV_DBG("Converted grammar: %s\n", grammar_str.c_str());
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, std::move(grammar_str)};
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
|
||||
std::string grammar_str = json_value(data, "grammar", std::string());
|
||||
if (!grammar_str.empty()) {
|
||||
// grammar_type key is set by the server when converting chat template grammars
|
||||
std::string grammar_type = json_value(data, "grammar_type", std::string());
|
||||
if (grammar_type == "tool_calls") {
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_TOOL_CALLS, std::move(grammar_str)};
|
||||
} else {
|
||||
// explicit grammar from the user (API field "grammar")
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)};
|
||||
}
|
||||
SRV_DBG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(params.sampling.grammar).c_str());
|
||||
}
|
||||
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
|
||||
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
|
||||
}
|
||||
|
|
@ -402,7 +413,9 @@ task_params server_task::params_from_json_cmpl(
|
|||
}
|
||||
params.chat_parser_params.reasoning_format = reasoning_format;
|
||||
params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
||||
params.chat_parser_params.thinking_forced_open = json_value(data, "thinking_forced_open", false);
|
||||
params.chat_parser_params.generation_prompt = json_value(data, "generation_prompt", std::string());
|
||||
params.sampling.generation_prompt = params.chat_parser_params.generation_prompt;
|
||||
SRV_DBG("Generation prompt: '%s'\n", params.chat_parser_params.generation_prompt.c_str());
|
||||
params.chat_parser_params.parse_tool_calls = json_value(data, "parse_tool_calls", false);
|
||||
if (data.contains("chat_parser")) {
|
||||
params.chat_parser_params.parser.load(data.at("chat_parser").get<std::string>());
|
||||
|
|
@ -469,10 +482,7 @@ task_params server_task::params_from_json_cmpl(
|
|||
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
|
||||
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
||||
const auto message = json_value(data, "reasoning_budget_message", std::string());
|
||||
const bool activate_imm = json_value(data, "reasoning_budget_activate_immediately", false);
|
||||
|
||||
params.sampling.reasoning_budget_tokens = budget;
|
||||
params.sampling.reasoning_budget_activate_immediately = activate_imm;
|
||||
|
||||
if (!start_tag.empty()) {
|
||||
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
|
||||
|
|
@ -482,8 +492,8 @@ task_params server_task::params_from_json_cmpl(
|
|||
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
|
||||
}
|
||||
|
||||
SRV_DBG("reasoning budget: tokens=%d, activate_immediately=%s, start=%zu toks, end=%zu toks, forced=%zu toks\n",
|
||||
budget, activate_imm ? "true" : "false",
|
||||
SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
|
||||
budget, params.sampling.generation_prompt.c_str(),
|
||||
params.sampling.reasoning_budget_start.size(),
|
||||
params.sampling.reasoning_budget_end.size(),
|
||||
params.sampling.reasoning_budget_forced.size());
|
||||
|
|
@ -746,6 +756,15 @@ json server_task_result_cmpl_final::to_json_non_oaicompat() {
|
|||
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::usage_json_oaicompat() {
|
||||
return json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
{"prompt_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }},
|
||||
};
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::to_json_oaicompat() {
|
||||
std::time_t t = std::time(0);
|
||||
json logprobs = json(nullptr); // OAI default to null
|
||||
|
|
@ -771,11 +790,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() {
|
|||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "text_completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
}},
|
||||
{"usage", usage_json_oaicompat()},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
|
|
@ -823,11 +838,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
|
|||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
}},
|
||||
{"usage", usage_json_oaicompat()},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
|
|
@ -892,11 +903,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
|||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
{"usage", usage_json_oaicompat()},
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -975,6 +982,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
|
|||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", n_decoded},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
{"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }},
|
||||
}},
|
||||
};
|
||||
|
||||
|
|
@ -1083,7 +1091,8 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
|
|||
{"usage", json {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", n_decoded},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
{"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }},
|
||||
}}
|
||||
}},
|
||||
}}
|
||||
|
|
@ -1149,7 +1158,8 @@ json server_task_result_cmpl_final::to_json_anthropic() {
|
|||
{"stop_reason", stop_reason},
|
||||
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
|
||||
{"usage", {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"cache_read_input_tokens", n_prompt_tokens_cache},
|
||||
{"input_tokens", n_prompt_tokens - n_prompt_tokens_cache},
|
||||
{"output_tokens", n_decoded}
|
||||
}}
|
||||
};
|
||||
|
|
@ -1659,7 +1669,8 @@ json server_task_result_cmpl_partial::to_json_anthropic() {
|
|||
{"stop_reason", nullptr},
|
||||
{"stop_sequence", nullptr},
|
||||
{"usage", {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"cache_read_input_tokens", n_prompt_tokens_cache},
|
||||
{"input_tokens", n_prompt_tokens - n_prompt_tokens_cache},
|
||||
{"output_tokens", 0}
|
||||
}}
|
||||
}}
|
||||
|
|
|
|||
|
|
@ -344,6 +344,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
bool truncated;
|
||||
int32_t n_decoded;
|
||||
int32_t n_prompt_tokens;
|
||||
int32_t n_prompt_tokens_cache;
|
||||
int32_t n_tokens_cached;
|
||||
bool has_new_line;
|
||||
std::string stopping_word;
|
||||
|
|
@ -387,6 +388,8 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json usage_json_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
|
||||
json to_json_oaicompat_chat();
|
||||
|
|
@ -408,6 +411,7 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
|
||||
int32_t n_decoded;
|
||||
int32_t n_prompt_tokens;
|
||||
int32_t n_prompt_tokens_cache;
|
||||
|
||||
bool post_sampling_probs;
|
||||
bool is_progress = false;
|
||||
|
|
|
|||
|
|
@ -51,6 +51,27 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
|||
assert choice["finish_reason"] == finish_reason
|
||||
|
||||
|
||||
def test_chat_completion_cached_tokens():
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.start()
|
||||
seq = [
|
||||
("1 2 3 4 5 6", 77, 0),
|
||||
("1 2 3 4 5 6", 77, 76),
|
||||
("1 2 3 4 5 9", 77, 51),
|
||||
("1 2 3 9 9 9", 77, 47),
|
||||
]
|
||||
for user_prompt, n_prompt, n_cache in seq:
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 8,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Test"},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
})
|
||||
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
||||
assert res.body["usage"]["prompt_tokens_details"]["cached_tokens"] == n_cache
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||
[
|
||||
|
|
@ -210,6 +231,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
|
|||
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
|
||||
global server
|
||||
server.jinja = jinja
|
||||
server.debug = True
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predicted,
|
||||
|
|
|
|||
|
|
@ -63,8 +63,10 @@ def test_anthropic_messages_basic():
|
|||
assert "text" in res.body["content"][0], "Text content block missing 'text' field"
|
||||
assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}"
|
||||
assert "usage" in res.body, "Missing 'usage' field"
|
||||
assert "cache_read_input_tokens" in res.body["usage"], "Missing usage.cache_read_input_tokens"
|
||||
assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens"
|
||||
assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens"
|
||||
assert isinstance(res.body["usage"]["cache_read_input_tokens"], int), "cache_read_input_tokens should be integer"
|
||||
assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer"
|
||||
assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer"
|
||||
assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens"
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ describe('ParameterSyncService', () => {
|
|||
chat_format: '',
|
||||
reasoning_format: '',
|
||||
reasoning_in_content: false,
|
||||
thinking_forced_open: false,
|
||||
generation_prompt: '',
|
||||
'speculative.n_max': 0,
|
||||
'speculative.n_min': 0,
|
||||
'speculative.p_min': 0.0,
|
||||
|
|
@ -116,7 +116,7 @@ describe('ParameterSyncService', () => {
|
|||
chat_format: '',
|
||||
reasoning_format: '',
|
||||
reasoning_in_content: false,
|
||||
thinking_forced_open: false,
|
||||
generation_prompt: '',
|
||||
'speculative.n_max': 0,
|
||||
'speculative.n_min': 0,
|
||||
'speculative.p_min': 0.0,
|
||||
|
|
|
|||
4
tools/server/webui/src/lib/types/api.d.ts
vendored
4
tools/server/webui/src/lib/types/api.d.ts
vendored
|
|
@ -164,7 +164,7 @@ export interface ApiLlamaCppServerProps {
|
|||
chat_format: string;
|
||||
reasoning_format: string;
|
||||
reasoning_in_content: boolean;
|
||||
thinking_forced_open: boolean;
|
||||
generation_prompt: string;
|
||||
samplers: string[];
|
||||
backend_sampling: boolean;
|
||||
'speculative.n_max': number;
|
||||
|
|
@ -332,7 +332,7 @@ export interface ApiSlotData {
|
|||
chat_format: string;
|
||||
reasoning_format: string;
|
||||
reasoning_in_content: boolean;
|
||||
thinking_forced_open: boolean;
|
||||
generation_prompt: string;
|
||||
samplers: string[];
|
||||
backend_sampling: boolean;
|
||||
'speculative.n_max': number;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue