diff --git a/common/arg.cpp b/common/arg.cpp index 3a26f0161..f1da2f950 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1314,6 +1314,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.kv_unified = value; } ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL})); + add_opt(common_arg( + {"--clear-idle"}, + {"--no-clear-idle"}, + "save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)", + [](common_params & params, bool value) { + params.clear_idle = value; + } + ).set_env("LLAMA_ARG_CLEAR_IDLE").set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--context-shift"}, {"--no-context-shift"}, diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 60b269c42..1b431ed15 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -6,6 +6,7 @@ #include "json-schema-to-grammar.h" #include "log.h" #include "nlohmann/json.hpp" +#include "peg-parser.h" #include #include @@ -317,6 +318,44 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont p.end(); } +common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name, + const common_peg_parser & call_id_section, bool have_call_id, + const common_peg_parser & args, + std::optional atomic_peek) const { + auto open = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix); + bool matched_atomic = false; + common_peg_parser func_parser = p.eps(); + + if (!function.name_suffix.empty()) { + func_parser = open + call_id_section + p.space() + args; + matched_atomic = true; + } else if (have_call_id) { + func_parser = p.atomic(open + call_id_section) + p.space() + args; + matched_atomic = true; + } else if (atomic_peek.has_value()) { + func_parser = 
p.atomic(open + call_id_section + p.space() + *atomic_peek) + args; + matched_atomic = true; + } else { + func_parser = open + call_id_section + p.space() + args; + } + + if (!function.close.empty()) { + func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close)); + } else if (!format.per_call_end.empty()) { + // When there's no func_close but there is a per_call_end marker, use peek() to ensure + // we only emit tool_close when we can actually see the closing marker. This prevents + // premature closing during partial parsing when we've seen e.g. "" (end) or "" prefix that failed to match. + func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end))); + } else { + func_parser = func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper + } + if (!matched_atomic) { + func_parser = p.atomic(func_parser); + } + return func_parser; +} + common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const { auto & p = ctx.p; const auto & inputs = ctx.inputs; @@ -330,17 +369,27 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context const auto & schema = func.contains("parameters") ? 
func.at("parameters") : json::object(); // Build call_id parser based on position (if supported) + bool have_call_id = false; common_peg_parser call_id_section = p.eps(); if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() && - !call_id.suffix.empty()) { - call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix; + (!call_id.suffix.empty() || !arguments.start.empty())) { + if (!call_id.suffix.empty()) { + call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix; + } else { + call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start))); + } + have_call_id = true; + } + auto args_parser = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)); + if (!arguments.start.empty()) { + args_parser = p.literal(arguments.start) + args_parser; + } + if (!arguments.end.empty()) { + args_parser = args_parser + p.literal(arguments.end); } - auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + - call_id_section + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)); - if (!function.close.empty()) { - func_parser = func_parser + function.close; - } + auto atomic_peek = !arguments.start.empty() ? std::optional(p.peek(p.literal(arguments.start))) : std::nullopt; + auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_parser, atomic_peek); tool_choice |= p.rule("tool-" + name, func_parser); }); @@ -400,12 +449,34 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte for (const auto & [param_name, param_schema] : properties.items()) { bool is_required = required.find(param_name) != required.end(); std::string type = "object"; - auto type_obj = param_schema.contains("type") ? 
param_schema.at("type") : json::object(); - if (type_obj.is_string()) { - type_obj.get_to(type); - } else if (type_obj.is_object()) { - if (type_obj.contains("type") && type_obj.at("type").is_string()) { - type_obj.at("type").get_to(type); + if (param_schema.contains("type")) { + const auto & type_obj = param_schema.at("type"); + if (type_obj.is_string()) { + type_obj.get_to(type); + } else if (type_obj.is_array()) { + // Handle nullable types like ["string", "null"] + for (const auto & t : type_obj) { + if (t.is_string() && t.get() != "null") { + type = t.get(); + break; + } + } + } else if (type_obj.is_object()) { + if (type_obj.contains("type") && type_obj.at("type").is_string()) { + type_obj.at("type").get_to(type); + } + } + } + // Infer string type from enum values when type is unspecified + if (type == "object" && param_schema.contains("enum")) { + const auto & enum_vals = param_schema.at("enum"); + if (enum_vals.is_array()) { + for (const auto & v : enum_vals) { + if (v.is_string()) { + type = "string"; + break; + } + } } } @@ -448,52 +519,31 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte args_seq = args_seq + p.repeat(p.space() + any_opt, 0, (int) optional_parsers.size()); } + if (!arguments.start.empty()) { + args_seq = p.literal(arguments.start) + args_seq; + } + if (!arguments.end.empty()) { + args_seq = args_seq + p.literal(arguments.end); + } + // Build call_id parser based on position (if supported) common_peg_parser call_id_section = p.eps(); bool have_call_id = false; if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() && - !call_id.suffix.empty()) { + (!call_id.suffix.empty() || !arguments.start.empty())) { have_call_id = true; - call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix); - } - - bool matched_atomic = false; - common_peg_parser func_parser = p.eps(); - if (!function.name_suffix.empty()) { - func_parser = 
p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + - call_id_section + p.space() + args_seq; - matched_atomic = true; - } else if (have_call_id) { - func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + - call_id_section) + p.space() + args_seq; - matched_atomic = true; - } else if (!arguments.name_prefix.empty() && !required_parsers.empty()) { - // Only peek for an arg tag when there are required args that must follow. - // When all args are optional, the model may emit no arg tags at all (#20650). - func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + - call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq; - matched_atomic = true; - } else { - func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + - call_id_section + p.space() + args_seq; - } - - if (!function.close.empty()) { - func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close)); - } else if (!format.per_call_end.empty()) { - // When there's no func_close but there is a per_call_end marker, use peek() to ensure - // we only emit tool_close when we can actually see the closing marker. This prevents - // premature closing during partial parsing when we've seen e.g. "" (end) or "" prefix that failed to match. 
- func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end))); - } else { - func_parser = - func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper - } - if (!matched_atomic) { - func_parser = p.atomic(func_parser); + if (!call_id.suffix.empty()) { + call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix); + } else { + call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start))); + } } + // Only peek for an arg tag when there are required args that must follow. + // When all args are optional, the model may emit no arg tags at all (#20650). + auto atomic_peek = (!arguments.name_prefix.empty() && !required_parsers.empty()) ? + std::optional(p.peek(p.literal(arguments.name_prefix))) : std::nullopt; + auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_seq, atomic_peek); tool_choice |= p.rule("tool-" + name, func_parser); }); @@ -574,9 +624,33 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_ std::vector arg_entries; for (const auto & [param_name, param_schema] : properties.items()) { - std::string type = "object"; - auto type_v = param_schema.contains("type") ? 
param_schema.at("type") : json::object(); - if (type_v.is_string()) type_v.get_to(type); + std::string type = "object"; + if (param_schema.contains("type")) { + const auto & type_v = param_schema.at("type"); + if (type_v.is_string()) { + type_v.get_to(type); + } else if (type_v.is_array()) { + // Handle nullable types like ["string", "null"] + for (const auto & t : type_v) { + if (t.is_string() && t.get() != "null") { + type = t.get(); + break; + } + } + } + } + // Infer string type from enum values when type is unspecified + if (type == "object" && param_schema.contains("enum")) { + const auto & enum_vals = param_schema.at("enum"); + if (enum_vals.is_array()) { + for (const auto & v : enum_vals) { + if (v.is_string()) { + type = "string"; + break; + } + } + } + } common_peg_parser value_parser = p.eps(); if (type == "string") { diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h index 8886c330d..2168bb05e 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -356,6 +356,13 @@ struct analyze_tools : analyze_base { common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const; common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const; common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const; + + // Shared helper: builds func_parser from open+call_id+args, handling atomic wrapping and close. + // atomic_peek: if present, used as the peek expression in the third atomicity branch. 
+ common_peg_parser build_func_parser(common_chat_peg_builder & p, const std::string & name, + const common_peg_parser & call_id_section, bool have_call_id, + const common_peg_parser & args, + std::optional atomic_peek) const; common_peg_parser build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const; }; diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp index aadade60f..828829663 100644 --- a/common/chat-diff-analyzer.cpp +++ b/common/chat-diff-analyzer.cpp @@ -25,6 +25,9 @@ static const std::string ARG_SECOND = "BB_ARG_SND_BB"; static const std::string USER_MSG = "U_USER_MSG Hello END_U"; static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A"; static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R"; +static const std::string CALL_ID_001 = "call00001"; +static const std::string CALL_ID_002 = "call00002"; +static const std::string CALL_ID_999 = "call99999"; static std::vector> workarounds( { // Old reasoning Qwen templates - they don't really display reasoning content, but we still want to @@ -131,6 +134,7 @@ static std::vector static std::string mode_to_str(T mode) { @@ -215,6 +219,11 @@ void autoparser::analyze_template(const common_chat_template & tmpl) { LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str()); LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str()); LOG_DBG("func_close: '%s'\n", tools.function.close.c_str()); + LOG_DBG("call_id_prefix: '%s'\n", tools.call_id.prefix.c_str()); + LOG_DBG("call_id_suffix: '%s'\n", tools.call_id.suffix.c_str()); + LOG_DBG("call_id_pos: '%s'\n", mode_to_str(tools.call_id.pos).c_str()); + LOG_DBG("args_start: '%s'\n", tools.arguments.start.c_str()); + LOG_DBG("args_end: '%s'\n", tools.arguments.end.c_str()); LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str()); LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str()); LOG_DBG("arg_value_prefix: '%s'\n", 
tools.arguments.value_prefix.c_str()); @@ -583,12 +592,15 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl, if (caps.supports_parallel_tool_calls) { check_per_call_markers(); } + LOG_DBG(ANSI_ORANGE "Phase 3a: Function call analysis\n" ANSI_RESET); extract_function_markers(); + LOG_DBG(ANSI_ORANGE "Phase 3b: Argument analysis\n" ANSI_RESET); if (format.mode == tool_format::TAG_WITH_TAGGED) { analyze_arguments(); } extract_argument_separator(); extract_args_markers(); + LOG_DBG(ANSI_ORANGE "Phase 3c: Call id analysis\n" ANSI_RESET); extract_call_id_markers(); } } @@ -979,8 +991,6 @@ void analyze_tools::extract_function_markers() { } void analyze_tools::analyze_arguments() { - LOG_DBG(ANSI_ORANGE "Phase 4: Argument analysis\n" ANSI_RESET); - extract_argument_name_markers(); extract_argument_value_markers(); } @@ -1189,7 +1199,7 @@ void analyze_tools::extract_args_markers() { const auto & diff = comparison->diff; - if (format.mode != tool_format::JSON_NATIVE) { + if (format.mode == tool_format::JSON_NATIVE) { std::string prefix_marker = !format.section_start.empty() ? format.section_start : format.per_call_start; std::string suffix_marker = !format.section_end.empty() ? 
format.section_end : format.per_call_end; // these might happen earlier in the tools section as an example or somewhere else, so we need to find the closest ones @@ -1211,6 +1221,10 @@ void analyze_tools::extract_args_markers() { if (find_fun != std::string::npos) { args_start = args_start.substr(find_fun + FUN_FIRST.size(), args_start.size() - find_fun - FUN_FIRST.size()); } + size_t find_call_id = args_start.find(CALL_ID_001); + if (find_call_id != std::string::npos) { + args_start = args_start.substr(find_call_id + CALL_ID_001.size(), args_start.size() - find_call_id - CALL_ID_001.size()); + } arguments.start = args_start; arguments.end = args_end; } @@ -1250,8 +1264,8 @@ void analyze_tools::extract_call_id_markers() { return; } - std::string id_value_1 = "call00001"; - std::string id_value_2 = "call99999"; + std::string id_value_1 = CALL_ID_001; + std::string id_value_2 = CALL_ID_999; size_t common_id_prefix_len = 0; for (size_t i = 0; i < std::min(id_value_1.length(), id_value_2.length()); i++) { @@ -1350,6 +1364,14 @@ void analyze_tools::extract_call_id_markers() { call_id.suffix = find_first_marker(before_func); } + if (call_id.prefix == arguments.end) { + call_id.prefix = ""; + } + + if (call_id.suffix == arguments.start) { + call_id.suffix = ""; + } + // When call_id is detected, per_call_end may have been incorrectly set to include // the call_id_suffix and sample args. Clear it if it starts with call_id_suffix. 
if (call_id.pos != call_id_position::NONE && !call_id.suffix.empty() && diff --git a/common/chat.cpp b/common/chat.cpp index e7decc952..bc4fc9eb4 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1646,7 +1646,7 @@ static json common_chat_extra_context() { return ctx; } -static std::optional try_specialized_template( +std::optional common_chat_try_specialized_template( const common_chat_template & tmpl, const std::string & src, const autoparser::generation_params & params) { @@ -1793,7 +1793,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ return data; } - if (auto result = try_specialized_template(tmpl, src, params)) { + if (auto result = common_chat_try_specialized_template(tmpl, src, params)) { result->generation_prompt = params.generation_prompt; return *result; } diff --git a/common/chat.h b/common/chat.h index a60a9228b..d5328379c 100644 --- a/common/chat.h +++ b/common/chat.h @@ -270,3 +270,8 @@ std::map common_chat_templates_get_caps(const common_chat_tem std::string common_chat_template_direct_apply( const common_chat_template & tmpl, const autoparser::generation_params & inputs); + +std::optional common_chat_try_specialized_template( + const common_chat_template & tmpl, + const std::string & src, + const autoparser::generation_params & params); diff --git a/common/common.h b/common/common.h index 572e6fad5..d39fa757e 100644 --- a/common/common.h +++ b/common/common.h @@ -576,8 +576,9 @@ struct common_params { int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting bool cache_prompt = true; // whether to enable prompt caching - int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot - int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill + bool clear_idle = true; // save and clear idle slots upon starting a new task + int32_t 
n_ctx_checkpoints = 32; // max number of context checkpoints per slot + int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index 2232790c3..5b51427aa 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -306,6 +306,19 @@ value filter_expression::execute_impl(context & ctx) { filter_id = "strip"; // alias } JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str()); + // TODO: Refactor filters so this coercion can be done automatically + if (!input->is_undefined() && !is_val(input) && ( + filter_id == "capitalize" || + filter_id == "lower" || + filter_id == "replace" || + filter_id == "strip" || + filter_id == "title" || + filter_id == "upper" || + filter_id == "wordcount" + )) { + JJ_DEBUG("Coercing %s to String for '%s' filter", input->type().c_str(), filter_id.c_str()); + input = mk_val(input->as_string()); + } return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx)); } else if (is_stmt(filter)) { diff --git a/common/jinja/value.cpp b/common/jinja/value.cpp index 749113124..7dc1d6540 100644 --- a/common/jinja/value.cpp +++ b/common/jinja/value.cpp @@ -465,8 +465,9 @@ const func_builtins & value_int_t::get_builtins() const { double val = static_cast(args.get_pos(0)->as_int()); return mk_val(val); }}, - {"tojson", tojson}, + {"safe", tojson}, {"string", tojson}, + {"tojson", tojson}, }; return builtins; } @@ -485,8 +486,9 @@ const func_builtins & value_float_t::get_builtins() const { int64_t val = static_cast(args.get_pos(0)->as_float()); return mk_val(val); }}, - {"tojson", tojson}, + {"safe", tojson}, {"string", tojson}, + {"tojson", tojson}, }; return builtins; } @@ -771,6 +773,11 @@ const func_builtins & value_string_t::get_builtins() const { const func_builtins & value_bool_t::get_builtins() 
const { + static const func_handler tostring = [](const func_args & args) -> value { + args.ensure_vals(); + bool val = args.get_pos(0)->as_bool(); + return mk_val(val ? "True" : "False"); + }; static const func_builtins builtins = { {"default", default_value}, {"int", [](const func_args & args) -> value { @@ -783,11 +790,8 @@ const func_builtins & value_bool_t::get_builtins() const { bool val = args.get_pos(0)->as_bool(); return mk_val(val ? 1.0 : 0.0); }}, - {"string", [](const func_args & args) -> value { - args.ensure_vals(); - bool val = args.get_pos(0)->as_bool(); - return mk_val(val ? "True" : "False"); - }}, + {"safe", tostring}, + {"string", tostring}, {"tojson", tojson}, }; return builtins; @@ -1100,18 +1104,14 @@ const func_builtins & value_object_t::get_builtins() const { } const func_builtins & value_none_t::get_builtins() const { + static const func_handler tostring = [](const func_args &) -> value { + return mk_val("None"); + }; static const func_builtins builtins = { {"default", default_value}, {"tojson", tojson}, - {"string", [](const func_args &) -> value { - return mk_val("None"); - }}, - {"safe", [](const func_args &) -> value { - return mk_val("None"); - }}, - {"strip", [](const func_args &) -> value { - return mk_val("None"); - }}, + {"string", tostring}, + {"safe", tostring}, {"items", empty_value_fn}, {"map", empty_value_fn}, {"reject", empty_value_fn}, diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index 694f9b850..86faacd61 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -1561,7 +1561,23 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo if (!s.schema) { return true; } - if (s.raw && s.schema->contains("type") && s.schema->at("type").is_string() && s.schema->at("type") == "string") { + if (s.raw && s.schema->contains("type")) { + const auto & type_val = s.schema->at("type"); + if (type_val.is_string() && type_val == "string") { + return true; + } + // Handle nullable types like 
["string", "null"] - delegate when the + // non-null type is string, since the tagged format uses raw text + if (type_val.is_array()) { + for (const auto & t : type_val) { + if (t.is_string() && t.get() != "null") { + return t.get() == "string"; + } + } + } + } + // Delegate for enum schemas in raw mode - enum values are literal strings + if (s.raw && !s.schema->contains("type") && s.schema->contains("enum")) { return true; } return false; diff --git a/tools/parser/debug-template-parser.cpp b/tools/parser/debug-template-parser.cpp index a83797157..9c591a1f1 100644 --- a/tools/parser/debug-template-parser.cpp +++ b/tools/parser/debug-template-parser.cpp @@ -5,15 +5,15 @@ #include "gguf.h" #include "jinja/runtime.h" #include "log.h" +#include "nlohmann/json.hpp" +#include "peg-parser.h" #include #include +#include #include #include -#include "nlohmann/json.hpp" -#include "peg-parser.h" - using json = nlohmann::ordered_json; enum class output_mode { @@ -34,14 +34,14 @@ enum class input_message_type { }; struct debug_options { - std::string template_path; - bool with_tools = true; - bool generation_prompt = true; - bool enable_reasoning = true; - bool debug_jinja = false; - bool force_tool_call = false; - output_mode mode = output_mode::BOTH; - input_message_type input_message = input_message_type::NONE; + std::string template_path; + bool with_tools = true; + bool generation_prompt = true; + bool enable_reasoning = true; + bool debug_jinja = false; + bool force_tool_call = false; + output_mode mode = output_mode::BOTH; + input_message_type input_message = input_message_type::NONE; }; static std::string read_file(const std::string & path) { @@ -274,7 +274,7 @@ static void render_scenario(const common_chat_template & tmpl, json final_messages = messages; if (add_generation_prompt && !messages.empty() && messages.back().value("role", "") == "assistant") { final_messages.push_back(json{ - { "role", "user" }, + { "role", "user" }, { "content", "Now please continue with 
another response." } }); } @@ -305,7 +305,7 @@ static void render_all_scenarios(const common_chat_template & tmpl, const json & tools, bool add_generation_prompt, bool enable_thinking, - input_message_type message_type) { + input_message_type message_type) { json user_msg = build_user_message(); auto render_if = [&](input_message_type type, const std::string & name, const json & assistant_msg) { @@ -335,6 +335,24 @@ static void render_all_scenarios(const common_chat_template & tmpl, } } +static autoparser::generation_params prepare_params(const debug_options & opts, const json & tools) { + autoparser::generation_params params; + params.messages = json::array({ build_user_message() }); + params.reasoning_format = opts.enable_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE; + params.enable_thinking = opts.enable_reasoning; + params.add_generation_prompt = opts.generation_prompt; + + if (opts.with_tools) { + params.tools = tools; + params.tool_choice = opts.force_tool_call ? COMMON_CHAT_TOOL_CHOICE_REQUIRED : COMMON_CHAT_TOOL_CHOICE_AUTO; + } else { + params.tools = json(); + params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; + } + params.parallel_tool_calls = false; + return params; +} + int main(int argc, char ** argv) { // Set log level to most verbose to capture all debug output common_log_set_verbosity_thold(99); @@ -369,49 +387,41 @@ int main(int argc, char ** argv) { try { common_chat_template chat_template(template_source, "", ""); - // Build tools definition json tools = opts.with_tools ? 
build_tools_definition() : json(); - // Render template scenarios if requested - if (opts.input_message != input_message_type::NONE && - (opts.mode == output_mode::TEMPLATE || opts.mode == output_mode::BOTH)) { + autoparser::generation_params params = prepare_params(opts, tools); + common_chat_params parser_data; + if (std::optional spec_tmpl = + common_chat_try_specialized_template(chat_template, template_source, params)) { LOG_ERR("\n"); - LOG_ERR("================================================================================\n"); - LOG_ERR(" TEMPLATE RENDERING OUTPUT\n"); - LOG_ERR("================================================================================\n"); + LOG_ERR("This template uses a specialized parser, analysis results will not be available."); + parser_data = *spec_tmpl; + } else { + // Render template scenarios if requested + if (opts.input_message != input_message_type::NONE && + (opts.mode == output_mode::TEMPLATE || opts.mode == output_mode::BOTH)) { + LOG_ERR("\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR(" TEMPLATE RENDERING OUTPUT\n"); + LOG_ERR("================================================================================\n"); - render_all_scenarios(chat_template, tools, opts.generation_prompt, opts.enable_reasoning, - opts.input_message); - } - - // Output analysis if requested - if (opts.mode == output_mode::ANALYSIS || opts.mode == output_mode::BOTH) { - LOG_ERR("\n"); - LOG_ERR("================================================================================\n"); - LOG_ERR(" TEMPLATE ANALYSIS\n"); - LOG_ERR("================================================================================\n"); - - autoparser::autoparser analysis; - analysis.analyze_template(chat_template); - - // Generate Parser - autoparser::generation_params params; - params.messages = json::array({ build_user_message() }); - params.reasoning_format = - opts.enable_reasoning ? 
COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE; - params.enable_thinking = opts.enable_reasoning; - params.add_generation_prompt = opts.generation_prompt; - - if (opts.with_tools) { - params.tools = tools; - params.tool_choice = opts.force_tool_call ? COMMON_CHAT_TOOL_CHOICE_REQUIRED : COMMON_CHAT_TOOL_CHOICE_AUTO; - } else { - params.tools = json(); - params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; + render_all_scenarios(chat_template, tools, opts.generation_prompt, opts.enable_reasoning, + opts.input_message); } - params.parallel_tool_calls = false; - auto parser_data = autoparser::peg_generator::generate_parser(chat_template, params, analysis); + // Output analysis if requested + if (opts.mode == output_mode::ANALYSIS || opts.mode == output_mode::BOTH) { + LOG_ERR("\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR(" TEMPLATE ANALYSIS\n"); + LOG_ERR("================================================================================\n"); + + autoparser::autoparser analysis; + analysis.analyze_template(chat_template); + + // Generate Parser + parser_data = autoparser::peg_generator::generate_parser(chat_template, params, analysis); + } LOG_ERR("\n=== Generated Parser ===\n"); common_peg_arena arena; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 6f737d94d..bd2552f75 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -605,6 +605,17 @@ private: llama_batch_free(batch); } + void slot_save_and_clear(server_slot & slot) { + if (slot.prompt.n_tokens() == 0) { + return; + } + SLT_INF(slot, "%s", "saving idle slot to prompt cache\n"); + SLT_DBG(slot, "%s", "__TEST_TAG_CLEAR_IDLE_SLOT__\n"); + slot.prompt_save(*prompt_cache); + slot.prompt_clear(false); + prompt_cache->update(); + } + void handle_sleeping_state(bool new_state) { GGML_ASSERT(sleeping != new_state); if (new_state) { @@ -864,6 +875,19 @@ private: 
metrics.init(); + if (params_base.clear_idle) { + if (!params_base.kv_unified) { + SRV_WRN("%s: --clear-idle requires --kv-unified, disabling\n", __func__); + params_base.clear_idle = false; + } else if (params_base.cache_ram_mib == 0) { + SRV_WRN("%s: --clear-idle requires --cache-ram, disabling\n", __func__); + params_base.clear_idle = false; + } else { + SRV_INF("%s: idle slots will be saved to prompt cache and cleared upon starting a new task\n", __func__); + SRV_DBG("%s", "__TEST_TAG_CLEAR_IDLE_ENABLED__\n"); + } + } + // populate webui settings { if (!params_base.webui_config_json.empty()) { @@ -1010,15 +1034,15 @@ private: // cache prompts only for completion tasks update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; - // don't update the cache if the slot's context is empty - update_cache = update_cache && tokens.size() > 0; - if (update_cache) { SRV_WRN("%s", "updating prompt cache\n"); const int64_t t_start = ggml_time_us(); - ret->prompt_save(*prompt_cache); + // don't save the slot's state if its context is empty + if (tokens.size() > 0) { + ret->prompt_save(*prompt_cache); + } if (!ret->prompt_load(*prompt_cache, task.tokens)) { ret->prompt_clear(false); @@ -1692,9 +1716,7 @@ private: const int id_slot = task.id_slot; const int id_task = task.id; - server_slot * slot = id_slot != -1 - ? get_slot_by_id(id_slot) - : get_available_slot(task); + server_slot * slot = id_slot != -1 ? 
            get_slot_by_id(id_slot) : get_available_slot(task);
 
    //
    // slot scheduling logic
@@ -1731,6 +1753,14 @@ private:
                         SRV_ERR("failed to launch slot with task, id_task = %d\n", id_task);
                         break; // drop the task
                     }
+
+                    if (params_base.clear_idle) {
+                        for (auto & s : slots) {
+                            if (!s.is_processing()) {
+                                slot_save_and_clear(s);
+                            }
+                        }
+                    }
                 } break;
             case SERVER_TASK_TYPE_CANCEL:
                 {
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 3018ac90f..4cc87bc50 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -2008,7 +2008,7 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
 bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
     const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
 
-    float f_keep_best = float(lcp_best) / prompt.tokens.size();
+    float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
     float sim_best    = float(lcp_best) / tokens_new.size();
 
     SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
diff --git a/tools/server/tests/unit/test_kv_keep_only_active.py b/tools/server/tests/unit/test_kv_keep_only_active.py
new file mode 100644
index 000000000..da93d5001
--- /dev/null
+++ b/tools/server/tests/unit/test_kv_keep_only_active.py
@@ -0,0 +1,115 @@
+import os
+import tempfile
+import pytest
+from utils import *
+
+server = ServerPreset.tinyllama2()
+
+class LogReader:
+    def __init__(self, path):
+        self.path = path
+        self.pos = 0
+    def drain(self):
+        with open(self.path) as f:
+            f.seek(self.pos)
+            content = f.read()
+            self.pos = f.tell()
+        return content
+
+@pytest.fixture(autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+    server.n_slots = 2
+    server.n_predict = 4
+    server.temperature = 0.0
+    server.server_slots = True
+    server.cache_ram = 100
+    server.kv_unified = True
+    server.debug = True
+    fd, server.log_path = tempfile.mkstemp(suffix='.log')
+    os.close(fd)
+    yield
+
+
+LONG_PROMPT = (
+    "Once upon a time in a land far away, there lived a brave knight "
+    "who traveled across mountains and rivers to find the legendary "
+    "golden sword hidden deep within the enchanted forest of whispers. "
+    "He met many creatures along the way including dragons and fairies "
+    "and wizards who helped him on his noble quest to save the kingdom."
+)
+
+
+# idle slot cleared on launch should restore from cache-ram
+def test_clear_and_restore():
+    global server
+    server.start()
+    log = LogReader(server.log_path)
+
+    # verify feature is enabled
+    assert "__TEST_TAG_CLEAR_IDLE_ENABLED__" in log.drain()
+
+    res = server.make_request("POST", "/completion", data={
+        "prompt": LONG_PROMPT,
+        "id_slot": 0,
+        "cache_prompt": True,
+    })
+    assert res.status_code == 200
+    original_prompt_n = res.body["timings"]["prompt_n"]
+
+    # Slot 0 is the only slot with KV — should NOT be cleared
+    assert "__TEST_TAG_CLEAR_IDLE_SLOT__" not in log.drain()
+
+    # Launching slot 1 clears idle slot 0
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "The quick brown fox",
+        "id_slot": 1,
+        "cache_prompt": True,
+    })
+    assert res.status_code == 200
+    assert "__TEST_TAG_CLEAR_IDLE_SLOT__" in log.drain()
+
+    # Re-send same prompt — should restore from cache-ram
+    res = server.make_request("POST", "/completion", data={
+        "prompt": LONG_PROMPT,
+        "cache_prompt": True,
+    })
+    assert res.status_code == 200
+    assert "updating prompt cache" in log.drain()
+    assert res.body["timings"]["cache_n"] > 0
+    assert res.body["timings"]["prompt_n"] < original_prompt_n
+
+    # Follow-up — slot 0 kept its KV, no clearing needed
+    res = server.make_request("POST", "/completion", data={
+        "prompt": LONG_PROMPT + " The knight finally reached the castle gates.",
+        "cache_prompt": True,
+    })
+    assert res.status_code == 200
+    assert "__TEST_TAG_CLEAR_IDLE_SLOT__" not in log.drain()
+
+
+def test_disabled_with_flag():
+    global server
+    server.no_clear_idle = True
+    server.start()
+    log = LogReader(server.log_path)
+
+    # Feature should not be enabled
+    assert "__TEST_TAG_CLEAR_IDLE_ENABLED__" not in log.drain()
+
+    res = server.make_request("POST", "/completion", data={
+        "prompt": LONG_PROMPT,
+        "id_slot": 0,
+        "cache_prompt": True,
+    })
+    assert res.status_code == 200
+
+    # Request on different slot — should NOT trigger clearing
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "The quick brown fox",
+        "id_slot": 1,
+        "cache_prompt": True,
+    })
+    assert res.status_code == 200
+    assert "__TEST_TAG_CLEAR_IDLE_SLOT__" not in log.drain()
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index a9a7e3c4f..5ddac5be4 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -102,6 +102,9 @@ class ServerProcess:
     mmproj_url: str | None = None
     media_path: str | None = None
     sleep_idle_seconds: int | None = None
+    cache_ram: int | None = None
+    no_clear_idle: bool = False
+    log_path: str | None = None
     webui_mcp_proxy: bool = False
 
     # session variables
@@ -237,6 +240,10 @@ class ServerProcess:
             server_args.extend(["--media-path", self.media_path])
         if self.sleep_idle_seconds is not None:
             server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])
+        if self.cache_ram is not None:
+            server_args.extend(["--cache-ram", self.cache_ram])
+        if self.no_clear_idle:
+            server_args.append("--no-clear-idle")
         if self.webui_mcp_proxy:
             server_args.append("--webui-mcp-proxy")
 
@@ -249,11 +256,16 @@ class ServerProcess:
            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
            flags |= subprocess.CREATE_NO_WINDOW
 
+        if self.log_path:
+            self._log = open(self.log_path, "w")
+        else:
+            self._log = sys.stdout
+
        self.process = subprocess.Popen(
            [str(arg) for arg in [server_path, *server_args]],
            creationflags=flags,
-            stdout=sys.stdout,
-            stderr=sys.stdout,
+            stdout=self._log,
+            stderr=self._log,
            env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
        )
        server_instances.add(self)
@@ -298,6 +310,8 @@ class ServerProcess:
            except Exception as e:
                print(f"Error waiting for server: {e}")
            self.process = None
+            if hasattr(self, '_log') and self._log != sys.stdout:
+                self._log.close()
 
    def make_request(
        self,