diff --git a/CMakeLists.txt b/CMakeLists.txt index 631b71271..ec8b96ebd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -460,6 +460,8 @@ add_library(common2 src/unicode-data.cpp otherarch/utils.cpp otherarch/utils.h + common/reasoning-budget.cpp + common/reasoning-budget.h tools/mtmd/mtmd-audio.cpp tools/mtmd/mtmd-audio.h) target_include_directories(common2 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./vendor/stb ./vendor ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./tools ./common) diff --git a/Makefile b/Makefile index 6874bb99c..cfcb49611 100644 --- a/Makefile +++ b/Makefile @@ -110,10 +110,10 @@ endif CUBLASLD_FLAGS = CUBLAS_OBJS = -OBJS_FULL += ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o kcpp-quantmapper.o kcpp-repackmapper.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o -OBJS_SIMPLE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx2.o ggml-cpu-quants.o kcpp-quantmapper_noavx2.o kcpp-repackmapper_noavx2.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx2.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o -OBJS_SIMPLER += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx1.o ggml-cpu-quants.o kcpp-quantmapper_noavx1.o kcpp-repackmapper_noavx1.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx1.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o -OBJS_FAILSAFE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_failsafe.o ggml-cpu-quants.o kcpp-quantmapper_failsafe.o kcpp-repackmapper_failsafe.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_failsafe.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o +OBJS_FULL += ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o kcpp-quantmapper.o kcpp-repackmapper.o unicode.o unicode-common.o 
unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o llama-impl.o sampling.o budget.o kcpputils.o mtmdaudio.o +OBJS_SIMPLE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx2.o ggml-cpu-quants.o kcpp-quantmapper_noavx2.o kcpp-repackmapper_noavx2.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx2.o common.o llama-impl.o sampling.o budget.o kcpputils.o mtmdaudio.o +OBJS_SIMPLER += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx1.o ggml-cpu-quants.o kcpp-quantmapper_noavx1.o kcpp-repackmapper_noavx1.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx1.o common.o llama-impl.o sampling.o budget.o kcpputils.o mtmdaudio.o +OBJS_FAILSAFE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_failsafe.o ggml-cpu-quants.o kcpp-quantmapper_failsafe.o kcpp-repackmapper_failsafe.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_failsafe.o common.o llama-impl.o sampling.o budget.o kcpputils.o mtmdaudio.o # OS specific ifeq ($(UNAME_S),Linux) @@ -675,7 +675,8 @@ expose.o: expose.cpp expose.h model_adapter.cpp $(CXX) $(CXXFLAGS) -c $< -o $@ llama-impl.o: src/llama-impl.cpp src/llama-impl.h $(CXX) $(CXXFLAGS) -c $< -o $@ - +budget.o: common/reasoning-budget.cpp common/reasoning-budget.h + $(CXX) $(CXXFLAGS) -c $< -o $@ # sd.cpp objects sdcpp_default.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/name_conversion.cpp otherarch/sdcpp/tokenize_util.cpp otherarch/sdcpp/thirdparty/zip.c diff --git a/common/arg.cpp b/common/arg.cpp index 232eb4573..0c5f3556e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2916,6 +2916,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { auto parsed = json::parse(value); 
for (const auto & item : parsed.items()) { + if (item.key() == "enable_thinking") { + LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. " + "Use --reasoning on / --reasoning off instead.\n"); + } params.default_template_kwargs[item.key()] = item.value().dump(); } } @@ -3051,14 +3055,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.reasoning_format = common_reasoning_format_from_name(value); } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK")); + add_opt(common_arg( + {"-rea", "--reasoning"}, "[on|off|auto]", + "Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))", + [](common_params & params, const std::string & value) { + if (is_truthy(value)) { + params.enable_reasoning = 1; + params.default_template_kwargs["enable_thinking"] = "true"; + } else if (is_falsey(value)) { + params.enable_reasoning = 0; + params.default_template_kwargs["enable_thinking"] = "false"; + } else if (is_autoy(value)) { + params.enable_reasoning = -1; + } else { + throw std::invalid_argument( + string_format("error: unknown value for --reasoning: '%s'\n", value.c_str())); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING")); add_opt(common_arg( {"--reasoning-budget"}, "N", - "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)", + "token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)", [](common_params & params, int value) { - if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); } + if (value < -1) { throw std::invalid_argument("invalid value"); } params.reasoning_budget = value; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, 
LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET")); + add_opt(common_arg( + {"--reasoning-budget-message"}, "MESSAGE", + "message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)", + [](common_params & params, const std::string & value) { + params.reasoning_budget_message = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 1c74ad30d..b7cf51394 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co if (thinking_forced_open || thinking_forced_closed) { // Thinking is forced open OR forced closed with enable_thinking=true // In both cases, expect only the closing tag (opening was in template) - return p.reasoning(p.until(end)) + end; + // However, since we might have incorrectly detected the open/close pattern, + // we admit an optional starting marker + return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end; } if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) { // Standard tag-based reasoning OR tools-only mode (reasoning appears with tools) diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 62689adee..4c5bb6218 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -6,7 +6,7 @@ #include -// using json = nlohmann::ordered_json; +using ordered_json = nlohmann::ordered_json; static std::string_view trim_trailing_space(std::string_view sv, int max = -1) { int count = 0; @@ -68,7 +68,7 @@ static int json_brace_depth(const std::string & s) { // JSON-escape a string and return the inner content (without surrounding quotes). 
static std::string escape_json_string_inner(const std::string & s) { - std::string escaped = json(s).dump(); + std::string escaped = ordered_json(s).dump(); if (escaped.size() >= 2 && escaped.front() == '"' && escaped.back() == '"') { return escaped.substr(1, escaped.size() - 2); } @@ -309,7 +309,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) { if (arg_count > 0) { arg_entry = ","; } - arg_entry += json(trim(node.text)).dump() + ":"; + arg_entry += ordered_json(trim(node.text)).dump() + ":"; ++arg_count; auto & target = args_target(); @@ -343,7 +343,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) { // Try to parse as JSON value (number, bool, null, object, array) try { - json parsed = json::parse(value_content); + ordered_json parsed = ordered_json::parse(value_content); if (parsed.is_string()) { // Don't add closing quote yet (added by arg_close) for monotonic streaming std::string escaped = parsed.dump(); @@ -408,7 +408,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) { common_peg_parser common_chat_peg_builder::standard_constructed_tools( const std::map & markers, - const nlohmann::json & tools, + const ordered_json & tools, bool parallel_tool_calls, bool force_tool_calls) { if (!tools.is_array() || tools.empty()) { @@ -439,7 +439,7 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools( } const auto & function = tool_def.at("function"); std::string name = function.at("name"); - nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + ordered_json params = function.contains("parameters") ? 
function.at("parameters") : ordered_json::object(); // Build argument parsers auto args = eps(); @@ -479,8 +479,8 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools( // Python-style tool calls: name(arg1="value1", arg2=123) // Used only by LFM2 for now, so we don't merge it into autoparser common_peg_parser common_chat_peg_builder::python_style_tool_calls( - const nlohmann::json & tools, - bool parallel_tool_calls) { + const ordered_json & tools, + bool parallel_tool_calls) { if (!tools.is_array() || tools.empty()) { return eps(); } @@ -493,7 +493,7 @@ common_peg_parser common_chat_peg_builder::python_style_tool_calls( } const auto & function = tool_def.at("function"); std::string name = function.at("name"); - nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object(); auto args = eps(); if (params.contains("properties") && !params["properties"].empty()) { @@ -555,11 +555,11 @@ static std::pair parse_key_spec(const std::string & ke // Mode 1: function_is_key — parse {"function_name": {...}} common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key( - const nlohmann::json & tools, - const std::string & args_key, - const std::string & effective_args_key, - const std::string & call_id_key, - const std::string & gen_call_id_key) { + const ordered_json & tools, + const std::string & args_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key) { auto tool_choices = choice(); @@ -569,7 +569,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key( } const auto & function = tool_def.at("function"); std::string name = function.at("name"); - nlohmann::json params = function.contains("parameters") ? 
function.at("parameters") : nlohmann::json::object(); + ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object(); // Build inner object fields std::vector inner_fields; @@ -634,11 +634,11 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key( // Mode 2: Nested keys (dot notation like "function.name") common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys( - const nlohmann::json & tools, - const std::string & effective_name_key, - const std::string & effective_args_key, - const std::string & call_id_key, - const std::string & gen_call_id_key) { + const ordered_json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key) { auto tool_choices = choice(); @@ -655,7 +655,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys( } const auto & function = tool_def.at("function"); std::string name = function.at("name"); - nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + ordered_json params = function.contains("parameters") ? 
function.at("parameters") : ordered_json::object(); auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() + literal("\"") + tool_name(literal(name)) + literal("\""); @@ -706,7 +706,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys( // Mode 3: Flat keys with optional ID fields and parameter ordering common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys( - const nlohmann::json & tools, + const ordered_json & tools, const std::string & effective_name_key, const std::string & effective_args_key, const std::string & call_id_key, @@ -723,7 +723,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys( } const auto & function = tool_def.at("function"); std::string name = function.at("name"); - nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object(); auto tool_name_ = name_key_parser + space() + literal(":") + space() + literal("\"") + tool_name(literal(name)) + literal("\""); @@ -791,7 +791,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys( common_peg_parser common_chat_peg_builder::standard_json_tools( const std::string & section_start, const std::string & section_end, - const nlohmann::json & tools, + const ordered_json & tools, bool parallel_tool_calls, bool force_tool_calls, const std::string & name_key, diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h index 5ea14be03..a497508d2 100644 --- a/common/chat-peg-parser.h +++ b/common/chat-peg-parser.h @@ -94,7 +94,7 @@ class common_chat_peg_builder : public common_peg_parser_builder { // parameters_order: order in which JSON fields should be parsed common_peg_parser standard_json_tools(const std::string & section_start, const std::string & section_end, - const nlohmann::json & tools, + const nlohmann::ordered_json & 
tools, bool parallel_tool_calls, bool force_tool_calls, const std::string & name_key = "", @@ -108,30 +108,30 @@ class common_chat_peg_builder : public common_peg_parser_builder { // Legacy-compatible helper for building XML/tagged style tool calls // Used by tests and manual parsers common_peg_parser standard_constructed_tools(const std::map & markers, - const nlohmann::json & tools, + const nlohmann::ordered_json & tools, bool parallel_tool_calls, bool force_tool_calls); // Helper for Python-style function call format: name(arg1="value1", arg2=123) // Used by LFM2 and similar templates - common_peg_parser python_style_tool_calls(const nlohmann::json & tools, - bool parallel_tool_calls); + common_peg_parser python_style_tool_calls(const nlohmann::ordered_json & tools, + bool parallel_tool_calls); private: // Implementation helpers for standard_json_tools — one per JSON tool call layout mode - common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools, - const std::string & args_key, - const std::string & effective_args_key, - const std::string & call_id_key, - const std::string & gen_call_id_key); + common_peg_parser build_json_tools_function_is_key(const nlohmann::ordered_json & tools, + const std::string & args_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key); - common_peg_parser build_json_tools_nested_keys(const nlohmann::json & tools, - const std::string & effective_name_key, - const std::string & effective_args_key, - const std::string & call_id_key, - const std::string & gen_call_id_key); + common_peg_parser build_json_tools_nested_keys(const nlohmann::ordered_json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key); - common_peg_parser build_json_tools_flat_keys(const nlohmann::json & tools, + common_peg_parser build_json_tools_flat_keys(const 
nlohmann::ordered_json & tools, const std::string & effective_name_key, const std::string & effective_args_key, const std::string & call_id_key, diff --git a/common/chat.cpp b/common/chat.cpp index 31c29e1ce..f4639ac46 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -8,6 +8,7 @@ #include "log.h" #include "json-partial.cpp" #include "regex-partial.cpp" +#include "reasoning-budget.h" #include "chat-auto-parser-generator.cpp" #include "chat-auto-parser-helpers.cpp" #include "chat-diff-analyzer.cpp" @@ -871,7 +872,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto include_grammar = true; - data.supports_thinking = true; + data.supports_thinking = true; + data.thinking_start_tag = "[THINK]"; + data.thinking_end_tag = "[/THINK]"; data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.preserved_tokens = { @@ -1179,9 +1182,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp const autoparser::templates_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; + data.thinking_start_tag = ""; + data.thinking_end_tag = ""; data.preserved_tokens = { "<|tool_calls_section_begin|>", "<|tool_calls_section_end|>", @@ -1541,6 +1546,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ autoparser.analyze_template(tmpl); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; + if 
(auto_params.supports_thinking) { + auto_params.thinking_start_tag = autoparser.reasoning.start; + auto_params.thinking_end_tag = autoparser.reasoning.end; + // FORCED_OPEN and FORCED_CLOSED both put in the generation prompt + // (FORCED_CLOSED forces empty when thinking is disabled, + // but forces open when thinking is enabled) + auto_params.thinking_forced_open = + autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN || + autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED; + } return auto_params; } catch (const std::exception & e) { throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what()); @@ -1634,8 +1649,8 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars build_chat_peg_parser([](common_chat_peg_builder & p) { return p.content(p.rest()) + p.end(); }) : src_parser; - if (src_parser.empty()) { - LOG_WRN("No parser definition detected, assuming pure content parser."); + if (src_parser.empty()) { + LOG_DBG("No parser definition detected, assuming pure content parser."); } LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str()); diff --git a/common/chat.h b/common/chat.h index 005cc5c8b..930987cf7 100644 --- a/common/chat.h +++ b/common/chat.h @@ -213,6 +213,8 @@ struct common_chat_params { bool grammar_lazy = false; bool thinking_forced_open = false; bool supports_thinking = false; + std::string thinking_start_tag; // e.g., "" + std::string thinking_end_tag; // e.g., "" std::vector grammar_triggers; std::vector preserved_tokens; std::vector additional_stops; diff --git a/common/common.h b/common/common.h index d48ed7900..7772ed63a 100644 --- a/common/common.h +++ b/common/common.h @@ -232,6 +232,14 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + // reasoning budget 
sampler parameters + // these are populated by the server/CLI based on chat template params + int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget + bool reasoning_budget_activate_immediately = false; + std::vector reasoning_budget_start; // start tag token sequence + std::vector reasoning_budget_end; // end tag token sequence + std::vector reasoning_budget_forced; // forced sequence (message + end tag) + bool backend_sampling = false; bool has_logit_bias() const { @@ -533,7 +541,9 @@ struct common_params { bool use_jinja = true; // NOLINT bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable int reasoning_budget = -1; + std::string reasoning_budget_message; // message injected before end tag when budget exhausted bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp new file mode 100644 index 000000000..a55e4f509 --- /dev/null +++ b/common/reasoning-budget.cpp @@ -0,0 +1,219 @@ +#include "reasoning-budget.h" +#include "common.h" +#include "unicode.h" + +#include "log.h" + +#include +#include +#include +#include + +struct token_matcher { + std::vector tokens; + size_t pos = 0; + + bool advance(llama_token token) { + if (tokens.empty()) { + return false; + } + + if (token == tokens[pos]) { + pos++; + if (pos >= tokens.size()) { + pos = 0; + return true; + } + } else { + pos = 0; + if (token == tokens[0]) { + pos = 1; + } + } + return false; + } + + void reset() { pos = 0; } +}; + +struct common_reasoning_budget_ctx { + const llama_vocab * vocab; + + token_matcher start_matcher; + token_matcher end_matcher; + std::vector forced_tokens; + + int32_t budget; // maximum tokens in reasoning block + int32_t 
remaining; // tokens remaining in budget + + common_reasoning_budget_state state; + + // for forcing + size_t force_pos; // next position in forced_tokens to force +}; + +static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) { + return "reasoning-budget"; +} + +static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + + switch (ctx->state) { + case REASONING_BUDGET_IDLE: + { + if (ctx->start_matcher.advance(token)) { + ctx->state = REASONING_BUDGET_COUNTING; + ctx->remaining = ctx->budget; + LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget); + + if (ctx->remaining <= 0) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + LOG_INF("reasoning-budget: budget=0, forcing immediately\n"); + } + } + break; + } + case REASONING_BUDGET_COUNTING: + case REASONING_BUDGET_WAITING_UTF8: + { + if (ctx->end_matcher.advance(token)) { + ctx->state = REASONING_BUDGET_DONE; + LOG_INF("reasoning-budget: deactivated (natural end)\n"); + break; + } + + bool utf8_complete = true; + if (ctx->vocab != nullptr) { + const std::string piece = common_token_to_piece(ctx->vocab, token, false); + utf8_complete = common_utf8_is_complete(piece); + } + + if (ctx->state == REASONING_BUDGET_WAITING_UTF8) { + if (utf8_complete) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n"); + } + } else if (ctx->state == REASONING_BUDGET_COUNTING) { + ctx->remaining--; + if (ctx->remaining <= 0) { + if (utf8_complete) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n"); + } else { + ctx->state = REASONING_BUDGET_WAITING_UTF8; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 
completion\n"); + } + } + } + break; + } + case REASONING_BUDGET_FORCING: + // force_pos is advanced in apply(), not here. + // This ensures the first forced token isn't skipped when the sampler + // is initialized directly in FORCING state (e.g. COUNTING + budget=0) + break; + case REASONING_BUDGET_DONE: + break; + } +} + +static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + + if (ctx->state != REASONING_BUDGET_FORCING) { + // passthrough — don't modify logits + return; + } + + if (ctx->force_pos >= ctx->forced_tokens.size()) { + return; + } + + const llama_token forced = ctx->forced_tokens[ctx->force_pos]; + + // set all logits to -inf except the forced token + for (size_t i = 0; i < cur_p->size; i++) { + if (cur_p->data[i].id != forced) { + cur_p->data[i].logit = -INFINITY; + } + } + + // advance to next forced token (done here rather than in accept so that + // the first forced token isn't skipped when starting in FORCING state) + ctx->force_pos++; + if (ctx->force_pos >= ctx->forced_tokens.size()) { + ctx->state = REASONING_BUDGET_DONE; + LOG_INF("reasoning-budget: forced sequence complete, done\n"); + } +} + +static void common_reasoning_budget_reset(struct llama_sampler * smpl) { + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + ctx->state = REASONING_BUDGET_IDLE; + ctx->remaining = ctx->budget; + ctx->start_matcher.reset(); + ctx->end_matcher.reset(); + ctx->force_pos = 0; +} + +static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; + return common_reasoning_budget_init( + ctx->vocab, + ctx->start_matcher.tokens, + ctx->end_matcher.tokens, + ctx->forced_tokens, + ctx->budget, + ctx->state); +} + +static void common_reasoning_budget_free(struct llama_sampler * smpl) { + delete (common_reasoning_budget_ctx *) smpl->ctx; +} + 
+static struct llama_sampler_i common_reasoning_budget_i = { + /* .name = */ common_reasoning_budget_name, + /* .accept = */ common_reasoning_budget_accept, + /* .apply = */ common_reasoning_budget_apply, + /* .reset = */ common_reasoning_budget_reset, + /* .clone = */ common_reasoning_budget_clone, + /* .free = */ common_reasoning_budget_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, +}; + +struct llama_sampler * common_reasoning_budget_init( + const struct llama_vocab * vocab, + const std::vector & start_tokens, + const std::vector & end_tokens, + const std::vector & forced_tokens, + int32_t budget, + common_reasoning_budget_state initial_state) { + // promote COUNTING with budget <= 0 to FORCING + if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) { + initial_state = REASONING_BUDGET_FORCING; + } + + return llama_sampler_init( + /* .iface = */ &common_reasoning_budget_i, + /* .ctx = */ new common_reasoning_budget_ctx { + /* .vocab = */ vocab, + /* .start_matcher = */ { start_tokens, 0 }, + /* .end_matcher = */ { end_tokens, 0 }, + /* .forced_tokens = */ forced_tokens, + /* .budget = */ budget, + /* .remaining = */ budget, + /* .state = */ initial_state, + /* .force_pos = */ 0, + } + ); +} diff --git a/common/reasoning-budget.h b/common/reasoning-budget.h new file mode 100644 index 000000000..08ad28248 --- /dev/null +++ b/common/reasoning-budget.h @@ -0,0 +1,41 @@ +#pragma once + +#include "llama.h" + +#include +#include + +enum common_reasoning_budget_state { + REASONING_BUDGET_IDLE, // waiting for start sequence + REASONING_BUDGET_COUNTING, // counting down tokens + REASONING_BUDGET_FORCING, // forcing budget message + end sequence + REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion + REASONING_BUDGET_DONE, // passthrough forever +}; + +// Creates a reasoning budget sampler that limits token generation inside a +// 
reasoning block (e.g. between and ). +// +// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE +// IDLE: passthrough, watching for start_tokens sequence +// COUNTING: counting down remaining tokens, watching for natural end_tokens +// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence +// FORCING: forces forced_tokens token-by-token (all other logits -> -inf) +// DONE: passthrough forever +// +// Parameters: +// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr) +// start_tokens - token sequence that activates counting +// end_tokens - token sequence for natural deactivation +// forced_tokens - token sequence forced when budget expires +// budget - max tokens allowed in the reasoning block +// initial_state - initial state of the sampler (e.g. IDLE or COUNTING) +// note: COUNTING with budget <= 0 is promoted to FORCING +// +struct llama_sampler * common_reasoning_budget_init( + const struct llama_vocab * vocab, + const std::vector & start_tokens, + const std::vector & end_tokens, + const std::vector & forced_tokens, + int32_t budget, + common_reasoning_budget_state initial_state); diff --git a/common/sampling.cpp b/common/sampling.cpp index 11a1d4839..f849d4f61 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" +#include "reasoning-budget.h" #include #include @@ -250,6 +251,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st } } + // reasoning budget sampler — added first so it can force tokens before other samplers + if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) { + samplers.push_back(common_reasoning_budget_init( + vocab, + params.reasoning_budget_start, + params.reasoning_budget_end, + params.reasoning_budget_forced, + params.reasoning_budget_tokens, + params.reasoning_budget_activate_immediately ? 
REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE)); + } + if (params.has_logit_bias()) { samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); } diff --git a/common/unicode.cpp b/common/unicode.cpp index c0ef6d029..f71fe5678 100644 --- a/common/unicode.cpp +++ b/common/unicode.cpp @@ -1,8 +1,10 @@ #include "unicode.h" + +#include #include #include -#include #include +#include // implementation adopted from src/unicode.cpp @@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off return utf8_parse_result(utf8_parse_result::INVALID); } +bool common_utf8_is_complete(const std::string & s) { + if (s.empty()) { + return true; + } + for (int i = 1; i <= std::min(4, (int)s.size()); i++) { + unsigned char c = s[s.size() - i]; + if ((c & 0xC0) != 0x80) { + int expected = (c >= 0xF0) ? 4 : (c >= 0xE0) ? 3 : (c >= 0xC0) ? 2 : 1; + return i >= expected; + } + } + return false; +} + std::string common_unicode_cpts_to_utf8(const std::vector & cps) { std::string result; for (size_t i = 0; i < cps.size(); ++i) { diff --git a/common/unicode.h b/common/unicode.h index 87bcc0ffc..9b32fa19d 100644 --- a/common/unicode.h +++ b/common/unicode.h @@ -20,6 +20,9 @@ struct utf8_parse_result { // Returns 0 for invalid first bytes size_t common_utf8_sequence_length(unsigned char first_byte); +// Check if a string ends with a complete UTF-8 sequence. 
+bool common_utf8_is_complete(const std::string & s); + // Parse a single UTF-8 codepoint from input utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 083b5bca9..4a4aac41d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4390,15 +4390,31 @@ class Qwen3Model(Qwen2Model): hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False) self.origin_hf_arch = hparams.get('architectures', [None])[0] - # a bit hacky, but currently the only way to detect if this is a rerank model - # ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B + if self._is_qwen3_reranker(): + self._find_rerank_config() + + def _is_qwen3_reranker(self) -> bool: readme_path = self.dir_model / "README.md" readme_text = "" if readme_path.exists(): with readme_path.open("r", encoding="utf-8") as f: readme_text = f.read() - if "# Qwen3-Reranker" in readme_text: - self._find_rerank_config() + + name_hints = [ + str(self.dir_model.name), + str(self.hparams.get("_name_or_path", "")), + str(self.hparams.get("model_type", "")), + str(self.origin_hf_arch or ""), + ] + name_hints = [hint.lower() for hint in name_hints if hint] + + if "# qwen3-reranker" in readme_text.lower() or "# qwen3-vl-reranker" in readme_text.lower(): + return True + + if any("qwen3-reranker" in hint or "qwen3-vl-reranker" in hint for hint in name_hints): + return True + + return "sequenceclassification" in (self.origin_hf_arch or "").lower() def set_vocab(self): # deal with intern-s1-mini @@ -9727,20 +9743,35 @@ class NemotronHModel(GraniteHybridModel): # M: Mamba2, *: Attention, -: MLP # MoE: # M: Mamba2, *: Attention, E: Expert - hybrid_override_pattern = self.hparams["hybrid_override_pattern"] - self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"] - self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")] + pattern = 
self.hparams.get("hybrid_override_pattern") or self.hparams.get("layers_block_type") + if pattern is None: + self._ssm_layers = [] + self._mlp_layers = [] + elif isinstance(pattern, str): + self._ssm_layers = [i for i, val in enumerate(pattern) if val == "M"] + self._mlp_layers = [i for i, val in enumerate(pattern) if val == ("E" if self.is_moe else "-")] + else: + self._ssm_layers = [i for i, val in enumerate(pattern) if val == "mamba"] + self._mlp_layers = [i for i, val in enumerate(pattern) if val == "moe"] def get_attn_layers(self): - hybrid_override_pattern = self.hparams["hybrid_override_pattern"] - assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!" - return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"] + pattern = self.hparams.get("hybrid_override_pattern") or self.hparams.get("layers_block_type") + if pattern is None: + return [] + assert len(pattern) == self.block_count, f"Mismatch between pattern ({len(pattern)}) and block_count ({self.block_count})!" + if isinstance(pattern, str): + return [i for i, val in enumerate(pattern) if val == "*"] + + return [i for i, val in enumerate(pattern) if val == "attention"] def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_key_length(self.head_dim) - self.gguf_writer.add_value_length(self.head_dim) + head_dim = self.head_dim + if head_dim is None: + raise ValueError("Could not find the attention head dim in config") + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) # Set feed_forward_length # NOTE: This will trigger an override warning. 
This is preferable to @@ -9768,6 +9799,9 @@ class NemotronHModel(GraniteHybridModel): if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: self.gguf_writer.add_expert_used_count(n_experts_used) + if (latent_size := self.hparams.get("moe_latent_size")) is not None: + self.gguf_writer.add_moe_latent_size(latent_size) + def set_vocab(self): super().set_vocab() @@ -9787,6 +9821,13 @@ class NemotronHModel(GraniteHybridModel): name = name[len("language_model."):] if self.is_moe and bid is not None: + # Skip Multi-Token Prediction (MTP) tensors. These are used for + # speculative decoding but we don't include them in this model + # conversion. See https://github.com/ggml-org/llama.cpp/pull/18886 + if "mtp" in name: + logger.info(f"gguf: Skipping MTP (Speculative) layer: {name}") + return [] + if name.endswith("mixer.gate.e_score_correction_bias"): new_name = name.replace("e_score_correction_bias", "e_score_correction.bias") yield from ModelBase.modify_tensors(self, data_torch, new_name, bid) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index df1ad2a51..1c11495b6 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -8,7 +8,12 @@ extern "C" { #define RPC_PROTO_MAJOR_VERSION 3 #define RPC_PROTO_MINOR_VERSION 6 -#define RPC_PROTO_PATCH_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 1 + +#ifdef __cplusplus +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); +#endif + #define GGML_RPC_MAX_SERVERS 16 // backend API diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index d8e811145..c249bbc86 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -2,28 +2,29 @@ #include "ggml-cuda/common.cuh" template -__global__ void gated_delta_net_cuda(const float * q, - const float * k, - const float * v, - const float * g, - const float * beta, - const float * curr_state, - float * dst, - int64_t H, - 
int64_t n_tokens, - int64_t n_seqs, - int64_t sq1, - int64_t sq2, - int64_t sq3, - int64_t sv1, - int64_t sv2, - int64_t sv3, - int64_t sb1, - int64_t sb2, - int64_t sb3, - int64_t rq1, - int64_t rq3, - float scale) { +__global__ void __launch_bounds__(S_v, 1) +gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + const int64_t H, + const int64_t n_tokens, + const int64_t n_seqs, + const int64_t sq1, + const int64_t sq2, + const int64_t sq3, + const int64_t sv1, + const int64_t sv2, + const int64_t sv3, + const int64_t sb1, + const int64_t sb2, + const int64_t sb3, + const int64_t rq1, + const int64_t rq3, + const float scale) { const int64_t h_idx = blockIdx.x; const int64_t sequence = blockIdx.y; const int col = threadIdx.x; // each thread owns one column @@ -40,8 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q, curr_state += state_offset; attn_data += (sequence * n_tokens * H + h_idx) * S_v; - // Load state column into registers + // GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229 + // TODO: check optimal path for RDNA1 and RDNA2 devices. 
+#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA) + extern __shared__ float s_shared[]; + float * s = s_shared + col * S_v; +#else float s[S_v]; +#endif #pragma unroll for (int i = 0; i < S_v; i++) { s[i] = curr_state[i * S_v + col]; @@ -114,6 +121,15 @@ __global__ void gated_delta_net_cuda(const float * q, } } +static size_t calculate_smem(const int sv, int cc) +{ + size_t smem = 0; + if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) { + smem = sv * sv * sizeof(float); + } + return smem; +} + template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, @@ -129,25 +145,36 @@ static void launch_gated_delta_net( dim3 grid_dims(H, n_seqs, 1); dim3 block_dims(S_v, 1, 1); + int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + switch (S_v) { - case 32: - gated_delta_net_cuda<32, KDA><<>>( + case 32: { + constexpr int sv = 32; + size_t smem = calculate_smem(sv, cc); + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; - case 64: - gated_delta_net_cuda<64, KDA><<>>( + } + case 64: { + constexpr int sv = 64; + size_t smem = calculate_smem(sv, cc); + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; - case 128: - gated_delta_net_cuda<128, KDA><<>>( + } + case 128: { + constexpr int sv = 128; + size_t smem = calculate_smem(sv, cc); + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; + } default: GGML_ABORT("fatal error"); break; diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 85e82b5a4..69985cd33 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ 
b/ggml/src/ggml-cuda/ssm-conv.cu @@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, int row = tid / load_cols; int col = tid % load_cols; #pragma unroll - for (int idx = tid; idx < total_elems; idx += split_d_inner) { + for (int idx = 0; idx < total_elems; idx += split_d_inner) { if (row < (int)split_d_inner) { smem[row * n_cols + col] = x_block[row * stride_x + col]; } @@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, col += split_d_inner; row += col / load_cols; col = col % load_cols; + if (idx >= total_elems - tid - split_d_inner) { + break; + } } __syncthreads(); diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 1136ce99b..855fd1ada 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -47,7 +47,7 @@ struct ggml_metal { uint64_t fuse_cnt[GGML_OP_COUNT]; // capture state - bool capture_next_compute; + int capture_compute; bool capture_started; id capture_scope; @@ -158,10 +158,17 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false"); GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? 
"true" : "false"); - res->capture_next_compute = false; + res->capture_compute = 0; res->capture_started = false; res->capture_scope = nil; + { + const char * val = getenv("GGML_METAL_CAPTURE_COMPUTE"); + if (val) { + res->capture_compute = atoi(val); + } + } + res->has_error = false; res->gf = nil; @@ -458,9 +465,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; - const bool use_capture = ctx->capture_next_compute; + if (ctx->capture_compute >= 0) { + ctx->capture_compute--; + } + + const bool use_capture = ctx->capture_compute == 0; if (use_capture) { - ctx->capture_next_compute = false; + ctx->capture_compute = -1; // make sure all previous computations have finished before starting the capture if (ctx->cmd_buf_last) { @@ -469,6 +480,10 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * } if (!ctx->capture_started) { + NSString * path = [NSString stringWithFormat:@"/tmp/perf-metal-%d.gputrace", getpid()]; + + GGML_LOG_WARN("%s: capturing graph in %s\n", __func__, [path UTF8String]); + // create capture scope id device = ggml_metal_device_get_obj(ctx->dev); ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device]; @@ -476,7 +491,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; descriptor.captureObject = ctx->capture_scope; descriptor.destination = MTLCaptureDestinationGPUTraceDocument; - descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + descriptor.outputURL = [NSURL fileURLWithPath:path]; NSError * error = nil; if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { @@ -683,7 +698,7 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { idx_end, ctx->use_fusion, ctx->use_concurrency, - 
ctx->capture_next_compute, + ctx->capture_compute, ctx->debug_graph, ctx->debug_fusion); @@ -718,5 +733,5 @@ bool ggml_metal_supports_family(ggml_metal_t ctx, int family) { } void ggml_metal_capture_next_compute(ggml_metal_t ctx) { - ctx->capture_next_compute = true; + ctx->capture_compute = 1; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index bf51055e3..99d64efc3 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -35,7 +35,7 @@ #define N_R0_Q4_K 2 #define N_SG_Q4_K 2 -#define N_R0_Q5_K 2 +#define N_R0_Q5_K 1 #define N_SG_Q5_K 2 #define N_R0_Q6_K 2 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 82ebbb4e4..29e4a245d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -9081,6 +9081,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; +template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>; template kernel void kernel_mul_mm_id( diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c5f546950..32fc9428a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -125,6 +125,7 @@ class Keys: EXPERT_GROUP_SCALE = "{arch}.expert_group_scale" EXPERTS_PER_GROUP = "{arch}.experts_per_group" MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers" + MOE_LATENT_SIZE = "{arch}.moe_latent_size" NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers" NUM_DEEPSTACK_LAYERS = "{arch}.n_deepstack_layers" POOLING_TYPE = "{arch}.pooling_type" @@ -543,6 
+544,8 @@ class MODEL_TENSOR(IntEnum): FFN_DOWN_CHEXP = auto() FFN_UP_CHEXP = auto() FFN_EXP_PROBS_B = auto() + MOE_LATENT_DOWN = auto() # nemotron 3 super + MOE_LATENT_UP = auto() # nemotron 3 super ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() @@ -986,6 +989,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps", MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", + MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super + MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n @@ -2913,6 +2918,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_INP, MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_DOWN_EXP, + # expert latent + MODEL_TENSOR.MOE_LATENT_DOWN, + MODEL_TENSOR.MOE_LATENT_UP, # shared expert MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e790be953..c89a5fdc3 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -859,6 +859,9 @@ class GGUFWriter: def add_moe_every_n_layers(self, value: int) -> None: self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value) + def add_moe_latent_size(self, value: int) -> None: + self.add_uint32(Keys.LLM.MOE_LATENT_SIZE.format(arch=self.arch), value) + def add_nextn_predict_layers(self, count: int) -> None: self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e57561090..18131e540 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -571,6 +571,14 @@ class TensorNameMap: 
"model.layers.{bid}.mlp.experts.gate_up_proj", ), + MODEL_TENSOR.MOE_LATENT_DOWN: ( + "backbone.layers.{bid}.mixer.fc1_latent_proj", # nemotron 3 super + ), + + MODEL_TENSOR.MOE_LATENT_UP: ( + "backbone.layers.{bid}.mixer.fc2_latent_proj", # nemotron 3 super + ), + # Feed-forward down MODEL_TENSOR.FFN_DOWN: ( "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ce49bbd98..799d16167 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -185,6 +185,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" }, { LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, + { LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, @@ -365,6 +366,8 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_LATENT_DOWN, "blk.%d.ffn_latent_down" }, + { LLM_TENSOR_FFN_LATENT_UP, "blk.%d.ffn_latent_up" }, { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, @@ -1087,6 +1090,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_OUTPUT, + LLM_TENSOR_CLS_OUT, LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_Q, LLM_TENSOR_ATTN_Q_NORM, @@ -1878,6 +1882,8 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_FFN_LATENT_DOWN, + LLM_TENSOR_FFN_LATENT_UP, // MoE shared expert layer LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, @@ -2753,6 +2759,9 @@ static const std::map 
LLM_TENSOR_INFOS = { {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // Nemotron 3 Super + {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 28dd1ffac..b1b1dcf18 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -189,6 +189,7 @@ enum llm_kv { LLM_KV_EXPERT_GROUP_SCALE, LLM_KV_EXPERTS_PER_GROUP, LLM_KV_MOE_EVERY_N_LAYERS, + LLM_KV_MOE_LATENT_SIZE, LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_NUM_DEEPSTACK_LAYERS, LLM_KV_POOLING_TYPE, @@ -385,6 +386,8 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_CHEXPS, LLM_TENSOR_FFN_UP_CHEXPS, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_FFN_LATENT_DOWN, + LLM_TENSOR_FFN_LATENT_UP, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 5f875136a..528f8e545 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -250,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { const bool last = ( cparams.pooling_type == LLAMA_POOLING_TYPE_LAST || - (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token + (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token ); for (int i = 0; i < n_tokens; ++i) { @@ -2552,7 +2552,7 @@ void llm_graph_context::build_pooling( } // softmax for qwen3 reranker - if (arch == LLM_ARCH_QWEN3) { + if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) { cur = ggml_soft_max(ctx0, cur); } } break; diff --git a/src/llama-hparams.h 
b/src/llama-hparams.h index abfd7f2c4..78c0bc27d 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -89,6 +89,7 @@ struct llama_hparams { bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; + uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; float f_norm_eps; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index fc74311bf..acfbfe944 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -249,6 +249,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_120B_A12B: return "120B.A12B"; case LLM_TYPE_122B_A10B: return "122B.A10B"; case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; @@ -1975,10 +1976,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); switch (hparams.n_layer) { case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; + case 88: type = LLM_TYPE_120B_A12B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -5702,6 +5705,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_ssm_head = hparams.ssm_dt_rank; const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? 
hparams.moe_latent_size : n_embd; // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5761,8 +5765,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); // MoE branch - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0); // Shared expert branch layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); diff --git a/src/llama-model.h b/src/llama-model.h index 5ecb8344a..74c79a774 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -126,6 +126,7 @@ enum llm_type { LLM_TYPE_100B_A6B, LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_120B_A12B, // Nemotron 3 Super LLM_TYPE_122B_A10B, // Qwen3.5 LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 @@ -294,6 +295,10 @@ struct llama_layer { struct ggml_tensor * ffn_up_exps_b = nullptr; struct ggml_tensor * ffn_gate_up_exps_b = nullptr; + // ff MoE latent proj + struct ggml_tensor * ffn_latent_down = nullptr; + struct ggml_tensor * ffn_latent_up = nullptr; + // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; struct ggml_tensor * ffn_gate_shexp = nullptr; diff --git a/src/llama-quant.cpp 
b/src/llama-quant.cpp index c5bc55f7d..e6f0c4fca 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -872,9 +872,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize_state_impl qs(model, params); - // these need to be set to n_layer by default - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; - if (params->only_copy) { ftype = ml.ftype; } @@ -981,6 +978,22 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // compute tensor metadata once and cache it std::vector metadata(tensors.size()); + // initialize quantization state before preliminary loop (counters for use_more_bits) + { + for (size_t i = 0; i < tensors.size(); ++i) { + const auto cat = tensor_get_category(tensors[i]->tensor->name); + if (category_is_attn_v(cat)) { + ++qs.n_attention_wv; + } + if (cat == tensor_category::OUTPUT) { + qs.has_tied_embeddings = false; + } + metadata[i].category = cat; // save and re-use the category while we're at it + } + // these also need to be set to n_layer by default + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; + } + // flag for --dry-run bool will_require_imatrix = false; @@ -993,16 +1006,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const struct ggml_tensor * tensor = it->tensor; const std::string name = ggml_get_name(tensor); - metadata[i].category = tensor_get_category(name); - - if (category_is_attn_v(metadata[i].category)) { - ++qs.n_attention_wv; - } - - if (tensor_name_match_output_weight(name.c_str())) { - qs.has_tied_embeddings = false; - } - uint16_t i_split = params->keep_split ? 
it->idx : 0; if (!ctx_outs[i_split]) { ctx_outs[i_split].reset(gguf_init_empty()); diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index 635821505..7af99174d 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -114,9 +114,18 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } else { - ggml_tensor * ffn_inp = cur; + ggml_tensor * inp_emb = cur; + ggml_tensor * inp_latent = cur; + + if (model.layers[il].ffn_latent_down) { + inp_latent = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_down, cur); + } + + ggml_tensor * router_logits = build_lora_mm(model.layers[il].ffn_gate_inp, cur); + cb(router_logits, "ffn_moe_logits", il); + ggml_tensor * moe_out = - build_moe_ffn(ffn_inp, + build_moe_ffn(inp_latent, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, nullptr, // no gate @@ -126,10 +135,15 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla LLM_FFN_RELU_SQR, hparams.expert_weights_norm, hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, - il); + il, + router_logits); cb(moe_out, "ffn_moe_out", il); - ggml_tensor * ffn_shexp = build_ffn(ffn_inp, + if (model.layers[il].ffn_latent_up) { + moe_out = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_up, moe_out); + } + + ggml_tensor * ffn_shexp = build_ffn(inp_emb, model.layers[il].ffn_up_shexp, NULL, NULL, NULL /* no gate */ , NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index ed3fc127b..3d0991dde 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 5b8895b34..bd203228c 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1101,6 +1101,22 @@ json oaicompat_chat_params_parse( 
llama_params["chat_parser"] = chat_params.parser; } + // Reasoning budget: pass parameters through to sampling layer + { + int reasoning_budget = opt.reasoning_budget; + if (reasoning_budget == -1 && body.contains("thinking_budget_tokens")) { + reasoning_budget = json_value(body, "thinking_budget_tokens", -1); + } + + if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) { + llama_params["reasoning_budget_tokens"] = reasoning_budget; + llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag; + llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag; + llama_params["reasoning_budget_message"] = opt.reasoning_budget_message; + llama_params["reasoning_budget_activate_immediately"] = chat_params.thinking_forced_open; + } + } + // Handle "logprobs" field // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future if (json_value(body, "logprobs", false)) { diff --git a/tools/server/server-common.h b/tools/server/server-common.h index a234541e1..3e56b3d85 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -287,6 +287,8 @@ struct server_chat_params { bool allow_image; bool allow_audio; bool enable_thinking = true; + int reasoning_budget = -1; + std::string reasoning_budget_message; std::string media_path; }; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index b86e7e608..b4373c101 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -893,9 +893,10 @@ private: } // thinking is enabled if: - // 1. It's not explicitly disabled (reasoning_budget == 0) + // 1. It's not explicitly disabled via --reasoning off // 2. 
The chat template supports it - const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); + const bool template_supports_thinking = params_base.use_jinja && common_chat_templates_support_enable_thinking(chat_templates.get()); + const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking; SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking); chat_params = { @@ -907,6 +908,8 @@ private: /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, /* enable_thinking */ enable_thinking, + /* reasoning_budget */ params_base.reasoning_budget, + /* reasoning_budget_msg */ params_base.reasoning_budget_message, /* media_path */ params_base.media_path, }; } diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 9d6e422d6..b3d510977 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -462,6 +462,34 @@ task_params server_task::params_from_json_cmpl( } } + // Parse reasoning budget sampler parameters + { + const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1); + if (budget >= 0) { + const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string()); + const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string()); + const auto message = json_value(data, "reasoning_budget_message", std::string()); + const bool activate_imm = json_value(data, "reasoning_budget_activate_immediately", false); + + params.sampling.reasoning_budget_tokens = budget; + params.sampling.reasoning_budget_activate_immediately = activate_imm; + + if (!start_tag.empty()) { + params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true); + } + if (!end_tag.empty()) { + params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true); + 
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true); + } + + SRV_DBG("reasoning budget: tokens=%d, activate_immediately=%s, start=%zu toks, end=%zu toks, forced=%zu toks\n", + budget, activate_imm ? "true" : "false", + params.sampling.reasoning_budget_start.size(), + params.sampling.reasoning_budget_end.size(), + params.sampling.reasoning_budget_forced.size()); + } + } + { params.sampling.logit_bias.clear(); diff --git a/tools/server/webui/src/lib/stores/agentic.svelte.ts b/tools/server/webui/src/lib/stores/agentic.svelte.ts index a6dd8581e..f8834f9df 100644 --- a/tools/server/webui/src/lib/stores/agentic.svelte.ts +++ b/tools/server/webui/src/lib/stores/agentic.svelte.ts @@ -318,6 +318,12 @@ class AgenticStore { const maxTurns = agenticConfig.maxTurns; const maxToolPreviewLines = agenticConfig.maxToolPreviewLines; + // Resolve effective model for vision capability checks. + // In ROUTER mode, options.model is always set by the caller. + // In MODEL mode, options.model is undefined; use the single loaded model + // which carries modalities bridged from /props. + const effectiveModel = options.model || modelsStore.models[0]?.model || ''; + for (let turn = 0; turn < maxTurns; turn++) { this.updateSession(conversationId, { currentTurn: turn + 1 }); agenticTimings.turns = turn + 1; @@ -571,14 +577,14 @@ class AgenticStore { ]; for (const attachment of attachments) { if (attachment.type === AttachmentType.IMAGE) { - if (modelsStore.modelSupportsVision(options.model ?? 
'')) { + if (modelsStore.modelSupportsVision(effectiveModel)) { contentParts.push({ type: ContentPartType.IMAGE_URL, image_url: { url: (attachment as DatabaseMessageExtraImageFile).base64Url } }); } else { console.info( - `[AgenticStore] Skipping image attachment (model "${options.model}" does not support vision)` + `[AgenticStore] Skipping image attachment (model "${effectiveModel}" does not support vision)` ); } } diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 7f76978fd..c8f88d87d 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -813,17 +813,13 @@ bool is_websocket_upgrade(const Request &req) { // Check Upgrade: websocket (case-insensitive) auto upgrade_it = req.headers.find("Upgrade"); if (upgrade_it == req.headers.end()) { return false; } - auto upgrade_val = upgrade_it->second; - std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(), - ::tolower); + auto upgrade_val = case_ignore::to_lower(upgrade_it->second); if (upgrade_val != "websocket") { return false; } // Check Connection header contains "Upgrade" auto connection_it = req.headers.find("Connection"); if (connection_it == req.headers.end()) { return false; } - auto connection_val = connection_it->second; - std::transform(connection_val.begin(), connection_val.end(), - connection_val.begin(), ::tolower); + auto connection_val = case_ignore::to_lower(connection_it->second); if (connection_val.find("upgrade") == std::string::npos) { return false; } // Check Sec-WebSocket-Key is a valid base64-encoded 16-byte value (24 chars) @@ -2615,10 +2611,15 @@ bool can_compress_content_type(const std::string &content_type) { switch (tag) { case "image/svg+xml"_t: case "application/javascript"_t: + case "application/x-javascript"_t: case "application/json"_t: + case "application/ld+json"_t: case "application/xml"_t: - case "application/protobuf"_t: - case "application/xhtml+xml"_t: return true; + case "application/xhtml+xml"_t: + 
case "application/rss+xml"_t: + case "application/atom+xml"_t: + case "application/xslt+xml"_t: + case "application/protobuf"_t: return true; case "text/event-stream"_t: return false; @@ -3038,17 +3039,13 @@ bool read_websocket_upgrade_response(Stream &strm, // Verify Upgrade: websocket (case-insensitive) auto upgrade_it = headers.find("Upgrade"); if (upgrade_it == headers.end()) { return false; } - auto upgrade_val = upgrade_it->second; - std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(), - ::tolower); + auto upgrade_val = case_ignore::to_lower(upgrade_it->second); if (upgrade_val != "websocket") { return false; } // Verify Connection header contains "Upgrade" (case-insensitive) auto connection_it = headers.find("Connection"); if (connection_it == headers.end()) { return false; } - auto connection_val = connection_it->second; - std::transform(connection_val.begin(), connection_val.end(), - connection_val.begin(), ::tolower); + auto connection_val = case_ignore::to_lower(connection_it->second); if (connection_val.find("upgrade") == std::string::npos) { return false; } // Verify Sec-WebSocket-Accept header value @@ -3934,14 +3931,10 @@ public: file_.content_type = trim_copy(header.substr(str_len(header_content_type))); } else { - thread_local const std::regex re_content_disposition( - R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", - std::regex_constants::icase); - - std::smatch m; - if (std::regex_match(header, m, re_content_disposition)) { + std::string disposition_params; + if (parse_content_disposition(header, disposition_params)) { Params params; - parse_disposition_params(m[1], params); + parse_disposition_params(disposition_params, params); auto it = params.find("name"); if (it != params.end()) { @@ -3956,13 +3949,14 @@ public: it = params.find("filename*"); if (it != params.end()) { - // Only allow UTF-8 encoding... 
- thread_local const std::regex re_rfc5987_encoding( - R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); - - std::smatch m2; - if (std::regex_match(it->second, m2, re_rfc5987_encoding)) { - file_.filename = decode_path_component(m2[1]); // override... + // RFC 5987: only UTF-8 encoding is allowed + const auto &val = it->second; + constexpr const char utf8_prefix[] = "UTF-8''"; + constexpr size_t prefix_len = str_len(utf8_prefix); + if (val.size() > prefix_len && + start_with_case_ignore(val, utf8_prefix)) { + file_.filename = decode_path_component( + val.substr(prefix_len)); // override... } else { is_valid_ = false; return false; @@ -4030,17 +4024,48 @@ private: file_.headers.clear(); } - bool start_with_case_ignore(const std::string &a, const char *b) const { + bool start_with_case_ignore(const std::string &a, const char *b, + size_t offset = 0) const { const auto b_len = strlen(b); - if (a.size() < b_len) { return false; } + if (a.size() < offset + b_len) { return false; } for (size_t i = 0; i < b_len; i++) { - if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + if (case_ignore::to_lower(a[offset + i]) != case_ignore::to_lower(b[i])) { return false; } } return true; } + // Parses "Content-Disposition: form-data; " without std::regex. + // Returns true if header matches, with the params portion in `params_out`. 
+ bool parse_content_disposition(const std::string &header, + std::string ¶ms_out) const { + constexpr const char prefix[] = "Content-Disposition:"; + constexpr size_t prefix_len = str_len(prefix); + + if (!start_with_case_ignore(header, prefix)) { return false; } + + // Skip whitespace after "Content-Disposition:" + auto pos = prefix_len; + while (pos < header.size() && (header[pos] == ' ' || header[pos] == '\t')) { + pos++; + } + + // Match "form-data;" (case-insensitive) + constexpr const char form_data[] = "form-data;"; + constexpr size_t form_data_len = str_len(form_data); + if (!start_with_case_ignore(header, form_data, pos)) { return false; } + pos += form_data_len; + + // Skip whitespace after "form-data;" + while (pos < header.size() && (header[pos] == ' ' || header[pos] == '\t')) { + pos++; + } + + params_out = header.substr(pos); + return true; + } + const std::string dash_ = "--"; const std::string crlf_ = "\r\n"; std::string boundary_; @@ -4992,9 +5017,10 @@ bool match_hostname(const std::string &pattern, // Verify certificate using Windows CertGetCertificateChain API. // This provides real-time certificate validation with Windows Update // integration, independent of the TLS backend (OpenSSL or MbedTLS). 
-bool verify_cert_with_windows_schannel( - const std::vector &der_cert, const std::string &hostname, - bool verify_hostname, unsigned long &out_error) { +bool +verify_cert_with_windows_schannel(const std::vector &der_cert, + const std::string &hostname, + bool verify_hostname, uint64_t &out_error) { if (der_cert.empty()) { return false; } out_error = 0; @@ -7987,7 +8013,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr, #else try { routed = routing(req, res, strm); - } catch (std::exception &e) { + } catch (std::exception &) { if (exception_handler_) { auto ep = std::current_exception(); exception_handler_(req, res, ep); @@ -11811,7 +11837,7 @@ bool SSLClient::initialize_ssl(Socket &socket, Error &error) { server_certificate_verification_) { verify_result_ = tls::get_verify_result(session); if (verify_result_ != 0) { - last_backend_error_ = static_cast(verify_result_); + last_backend_error_ = static_cast(verify_result_); error = Error::SSLServerVerification; output_error_log(error, nullptr); return false; @@ -11850,7 +11876,7 @@ bool SSLClient::initialize_ssl(Socket &socket, Error &error) { ca_cert_dir_path_.empty() && ca_cert_pem_.empty()) { std::vector der; if (get_cert_der(server_cert, der)) { - unsigned long wincrypt_error = 0; + uint64_t wincrypt_error = 0; if (!detail::verify_cert_with_windows_schannel( der, host_, server_hostname_verification_, wincrypt_error)) { last_backend_error_ = wincrypt_error; @@ -11974,16 +12000,26 @@ bool is_ipv4_address(const std::string &str) { // Parse IPv4 address string to bytes bool parse_ipv4(const std::string &str, unsigned char *out) { - int parts[4]; - if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2], - &parts[3]) != 4) { - return false; - } + const char *p = str.c_str(); for (int i = 0; i < 4; i++) { - if (parts[i] < 0 || parts[i] > 255) return false; - out[i] = static_cast(parts[i]); + if (i > 0) { + if (*p != '.') { return false; } + p++; + } + int val = 0; + int digits = 0; 
+ while (*p >= '0' && *p <= '9') { + val = val * 10 + (*p - '0'); + if (val > 255) { return false; } + p++; + digits++; + } + if (digits == 0) { return false; } + // Reject leading zeros (e.g., "01.002.03.04") to prevent ambiguity + if (digits > 1 && *(p - digits) == '0') { return false; } + out[i] = static_cast(val); } - return true; + return *p == '\0'; } #ifdef _WIN32 @@ -13285,11 +13321,11 @@ void update_server_certs_from_x509(ctx_t ctx, X509 *cert, EVP_PKEY *key, ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key, const char *password, - unsigned long &out_error) { + uint64_t &out_error) { out_error = 0; auto ctx = create_client_context(); if (!ctx) { - out_error = static_cast(get_error()); + out_error = get_error(); return nullptr; } @@ -13303,7 +13339,7 @@ ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key, } if (!set_client_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), password)) { - out_error = static_cast(get_error()); + out_error = get_error(); free_context(ctx); return nullptr; } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index aea6fd308..ac1908f42 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.35.0" -#define CPPHTTPLIB_VERSION_NUM "0x002300" +#define CPPHTTPLIB_VERSION "0.37.0" +#define CPPHTTPLIB_VERSION_NUM "0x002500" /* * Platform compatibility check @@ -575,6 +575,14 @@ inline unsigned char to_lower(int c) { return table[(unsigned char)(char)c]; } +inline std::string to_lower(const std::string &s) { + std::string result = s; + std::transform( + result.begin(), result.end(), result.begin(), + [](unsigned char c) { return static_cast(to_lower(c)); }); + return result; +} + inline bool equal(const std::string &a, const std::string &b) { return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { @@ -1859,23 +1867,23 @@ 
public: : res_(std::move(res)), err_(err), request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {} Result(std::unique_ptr &&res, Error err, Headers &&request_headers, - int ssl_error, unsigned long ssl_backend_error) + int ssl_error, uint64_t ssl_backend_error) : res_(std::move(res)), err_(err), request_headers_(std::move(request_headers)), ssl_error_(ssl_error), ssl_backend_error_(ssl_backend_error) {} int ssl_error() const { return ssl_error_; } - unsigned long ssl_backend_error() const { return ssl_backend_error_; } + uint64_t ssl_backend_error() const { return ssl_backend_error_; } private: int ssl_error_ = 0; - unsigned long ssl_backend_error_ = 0; + uint64_t ssl_backend_error_ = 0; #endif #ifdef CPPHTTPLIB_OPENSSL_SUPPORT public: [[deprecated("Use ssl_backend_error() instead")]] - unsigned long ssl_openssl_error() const { + uint64_t ssl_openssl_error() const { return ssl_backend_error_; } #endif @@ -2345,7 +2353,7 @@ protected: bool server_hostname_verification_ = true; std::string ca_cert_pem_; // Store CA cert PEM for redirect transfer int last_ssl_error_ = 0; - unsigned long last_backend_error_ = 0; + uint64_t last_backend_error_ = 0; #endif #ifdef CPPHTTPLIB_OPENSSL_SUPPORT diff --git a/vendor/miniaudio/miniaudio.h b/vendor/miniaudio/miniaudio.h index 24e676bb2..c6d493ee8 100644 --- a/vendor/miniaudio/miniaudio.h +++ b/vendor/miniaudio/miniaudio.h @@ -1,6 +1,6 @@ /* Audio playback and capture library. Choice of public domain or MIT-0. See license statements at the end of this file. -miniaudio - v0.11.24 - 2026-01-17 +miniaudio - v0.11.25 - 2026-03-04 David Reid - mackron@gmail.com @@ -3747,7 +3747,7 @@ extern "C" { #define MA_VERSION_MAJOR 0 #define MA_VERSION_MINOR 11 -#define MA_VERSION_REVISION 24 +#define MA_VERSION_REVISION 25 #define MA_VERSION_STRING MA_XSTRINGIFY(MA_VERSION_MAJOR) "." MA_XSTRINGIFY(MA_VERSION_MINOR) "." 
MA_XSTRINGIFY(MA_VERSION_REVISION) #if defined(_MSC_VER) && !defined(__clang__) @@ -19358,7 +19358,7 @@ MA_API ma_handle ma_dlopen(ma_log* pLog, const char* filename) #else /* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... */ WCHAR filenameW[4096]; - if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) { + if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, ma_countof(filenameW)) == 0) { handle = NULL; } else { handle = (ma_handle)LoadPackagedLibrary(filenameW, 0); @@ -41495,18 +41495,37 @@ Web Audio Backend #ifdef MA_HAS_WEBAUDIO #include -#if (__EMSCRIPTEN_major__ > 3) || (__EMSCRIPTEN_major__ == 3 && (__EMSCRIPTEN_minor__ > 1 || (__EMSCRIPTEN_minor__ == 1 && __EMSCRIPTEN_tiny__ >= 32))) +#ifndef MA_EMSCRIPTEN_MAJOR + #if defined(__EMSCRIPTEN_MAJOR__) + #define MA_EMSCRIPTEN_MAJOR __EMSCRIPTEN_MAJOR__ + #else + #define MA_EMSCRIPTEN_MAJOR __EMSCRIPTEN_major__ + #endif +#endif +#ifndef MA_EMSCRIPTEN_MINOR + #if defined(__EMSCRIPTEN_MINOR__) + #define MA_EMSCRIPTEN_MINOR __EMSCRIPTEN_MINOR__ + #else + #define MA_EMSCRIPTEN_MINOR __EMSCRIPTEN_minor__ + #endif +#endif +#ifndef MA_EMSCRIPTEN_TINY + #if defined(__EMSCRIPTEN_TINY__) + #define MA_EMSCRIPTEN_TINY __EMSCRIPTEN_TINY__ + #else + #define MA_EMSCRIPTEN_TINY __EMSCRIPTEN_tiny__ + #endif +#endif + +#if (MA_EMSCRIPTEN_MAJOR > 3) || (MA_EMSCRIPTEN_MAJOR == 3 && (MA_EMSCRIPTEN_MINOR > 1 || (MA_EMSCRIPTEN_MINOR == 1 && MA_EMSCRIPTEN_TINY >= 32))) #include #define MA_SUPPORT_AUDIO_WORKLETS - #if (__EMSCRIPTEN_major__ > 3) || (__EMSCRIPTEN_major__ == 3 && (__EMSCRIPTEN_minor__ > 1 || (__EMSCRIPTEN_minor__ == 1 && __EMSCRIPTEN_tiny__ >= 70))) + #if (MA_EMSCRIPTEN_MAJOR > 3) || (MA_EMSCRIPTEN_MAJOR == 3 && (MA_EMSCRIPTEN_MINOR > 1 || (MA_EMSCRIPTEN_MINOR == 1 && MA_EMSCRIPTEN_TINY >= 70))) #define MA_SUPPORT_AUDIO_WORKLETS_VARIABLE_BUFFER_SIZE #endif #endif -/* -TODO: Version 0.12: Swap this logic around so that AudioWorklets are used by default. 
Add MA_NO_AUDIO_WORKLETS. -*/ #if defined(MA_ENABLE_AUDIO_WORKLETS) && defined(MA_SUPPORT_AUDIO_WORKLETS) #define MA_USE_AUDIO_WORKLETS #endif @@ -59243,6 +59262,10 @@ static ma_result ma_data_source_read_pcm_frames_within_range(ma_data_source* pDa ma_uint64 framesRead = 0; ma_bool32 loop = ma_data_source_is_looping(pDataSource); + if (pFramesRead != NULL) { + *pFramesRead = 0; + } + if (pDataSourceBase == NULL) { return MA_AT_END; } @@ -61921,7 +61944,7 @@ extern "C" { #define MA_DR_WAV_XSTRINGIFY(x) MA_DR_WAV_STRINGIFY(x) #define MA_DR_WAV_VERSION_MAJOR 0 #define MA_DR_WAV_VERSION_MINOR 14 -#define MA_DR_WAV_VERSION_REVISION 4 +#define MA_DR_WAV_VERSION_REVISION 5 #define MA_DR_WAV_VERSION_STRING MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MAJOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MINOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_REVISION) #include #define MA_DR_WAVE_FORMAT_PCM 0x1 @@ -80503,6 +80526,13 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa MA_DR_WAV_ASSERT(pChunkHeader != NULL); if (pMetadata != NULL && bytesJustRead == sizeof(smplHeaderData)) { ma_uint32 iSampleLoop; + ma_uint32 loopCount; + ma_uint32 calculatedLoopCount; + loopCount = ma_dr_wav_bytes_to_u32(smplHeaderData + 28); + calculatedLoopCount = (pChunkHeader->sizeInBytes - MA_DR_WAV_SMPL_BYTES) / MA_DR_WAV_SMPL_LOOP_BYTES; + if (loopCount != calculatedLoopCount) { + return totalBytesRead; + } pMetadata->type = ma_dr_wav_metadata_type_smpl; pMetadata->data.smpl.manufacturerId = ma_dr_wav_bytes_to_u32(smplHeaderData + 0); pMetadata->data.smpl.productId = ma_dr_wav_bytes_to_u32(smplHeaderData + 4); @@ -80513,7 +80543,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa pMetadata->data.smpl.smpteOffset = ma_dr_wav_bytes_to_u32(smplHeaderData + 24); pMetadata->data.smpl.sampleLoopCount = ma_dr_wav_bytes_to_u32(smplHeaderData + 28); pMetadata->data.smpl.samplerSpecificDataSizeInBytes = ma_dr_wav_bytes_to_u32(smplHeaderData + 
32); - if (pMetadata->data.smpl.sampleLoopCount == (pChunkHeader->sizeInBytes - MA_DR_WAV_SMPL_BYTES) / MA_DR_WAV_SMPL_LOOP_BYTES) { + if (pMetadata->data.smpl.sampleLoopCount == calculatedLoopCount) { pMetadata->data.smpl.pLoops = (ma_dr_wav_smpl_loop*)ma_dr_wav__metadata_get_memory(pParser, sizeof(ma_dr_wav_smpl_loop) * pMetadata->data.smpl.sampleLoopCount, MA_DR_WAV_METADATA_ALIGNMENT); for (iSampleLoop = 0; iSampleLoop < pMetadata->data.smpl.sampleLoopCount; ++iSampleLoop) { ma_uint8 smplLoopData[MA_DR_WAV_SMPL_LOOP_BYTES]; @@ -80534,6 +80564,8 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa MA_DR_WAV_ASSERT(pMetadata->data.smpl.pSamplerSpecificData != NULL); ma_dr_wav__metadata_parser_read(pParser, pMetadata->data.smpl.pSamplerSpecificData, pMetadata->data.smpl.samplerSpecificDataSizeInBytes, &totalBytesRead); } + } else { + MA_DR_WAV_ZERO_OBJECT(&pMetadata->data.smpl); } } return totalBytesRead; @@ -83149,19 +83181,13 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); - pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; - if (pWav->msadpcm.delta[0] < 16) { - pWav->msadpcm.delta[0] = 16; - } + pWav->msadpcm.delta[0] = (ma_int32)ma_dr_wav_clamp(((ma_int64)adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8, 16, 0x7FFFFFFF); pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.prevFrames[0][1] = newSample0; newSample1 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample1 += nibble1 * pWav->msadpcm.delta[0]; 
newSample1 = ma_dr_wav_clamp(newSample1, -32768, 32767); - pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8; - if (pWav->msadpcm.delta[0] < 16) { - pWav->msadpcm.delta[0] = 16; - } + pWav->msadpcm.delta[0] = (ma_int32)ma_dr_wav_clamp(((ma_int64)adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8, 16, 0x7FFFFFFF); pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.prevFrames[0][1] = newSample1; pWav->msadpcm.cachedFrames[2] = newSample0; @@ -83176,10 +83202,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); - pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; - if (pWav->msadpcm.delta[0] < 16) { - pWav->msadpcm.delta[0] = 16; - } + pWav->msadpcm.delta[0] = (ma_int32)ma_dr_wav_clamp(((ma_int64)adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8, 16, 0x7FFFFFFF); pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.prevFrames[0][1] = newSample0; if (pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { @@ -83188,10 +83211,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8; newSample1 += nibble1 * pWav->msadpcm.delta[1]; newSample1 = ma_dr_wav_clamp(newSample1, -32768, 32767); - pWav->msadpcm.delta[1] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8; - if 
(pWav->msadpcm.delta[1] < 16) { - pWav->msadpcm.delta[1] = 16; - } + pWav->msadpcm.delta[1] = (ma_int32)ma_dr_wav_clamp(((ma_int64)adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8, 16, 0x7FFFFFFF); pWav->msadpcm.prevFrames[1][0] = pWav->msadpcm.prevFrames[1][1]; pWav->msadpcm.prevFrames[1][1] = newSample1; pWav->msadpcm.cachedFrames[2] = newSample0; @@ -95825,7 +95845,7 @@ For more information, please refer to =============================================================================== ALTERNATIVE 2 - MIT No Attribution =============================================================================== -Copyright 2025 David Reid +Copyright 2026 David Reid Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in