diff --git a/common/chat.cpp b/common/chat.cpp index 1262c3745..e25d0e2b2 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1106,6 +1106,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + + if (inputs.add_generation_prompt && string_ends_with(data.prompt, "\n")) { + // This may happen if the model generates content + tool_call, the + // template does not add the model's next turn and confuses the model + // from emitting its proper reasoning token sequence. + data.prompt += "<|turn>model\n"; + } + data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; data.supports_thinking = true; data.thinking_start_tag = "<|channel>thought"; @@ -1133,7 +1141,8 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ p.rule("thought", p.content(p.literal("<|channel>thought") + p.space() + p.until("") + p.literal(""))); } - auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>")); + auto consume_empty_channels = p.gbnf(p.zero_or_more(p.literal("<|channel>") + p.negate(p.literal("thought"))), ""); + auto thought = (p.peek(p.literal("<|channel>")) + consume_empty_channels + p.ref("thought")) | p.negate(p.literal("<|channel>")); if (has_response_format) { auto response_format = p.literal("```json") << @@ -1197,12 +1206,16 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ /* max = */ inputs.parallel_tool_calls ? -1 : 1 )); - auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"}))); + auto scan_to_toolcall = p.rule("scan-to-toolcall", p.until("<|tool_call>")); + auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "", "<|tool_call>"}))); auto message = p.rule("message", thought + content); - return start + p.zero_or_more(message) + tool_call; + return start + p.zero_or_more(message) + scan_to_toolcall + tool_call; } - auto content = p.rule("content", p.content(p.until("<|channel>"))); + // Gemma 4 may emit an extra <|channel>thought\n at the end of the content. It may + // also emit a single trailing token. Consume all complete reasoning blocks and + // then stop at the first unmatched token. + auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", ""}))); auto message = p.rule("message", thought + content); return start + p.one_or_more(message); }); @@ -1671,6 +1684,173 @@ static common_chat_params common_chat_params_init_gigachat_v3( return data; } +static common_chat_params common_chat_params_init_deepseek_v3_2(const common_chat_template & tmpl, + const autoparser::generation_params & inputs) { + common_chat_params data; + + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; + data.thinking_start_tag = ""; + data.thinking_end_tag = ""; + data.preserved_tokens = { + "|DSML|", + "", + "", + }; + + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object(); + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE); + + const std::string DSML = "|DSML|"; + const std::string THINK_START = ""; + const std::string THINK_END = ""; + const std::string FC_START = "<" + DSML + "function_calls>"; + const std::string FC_END = ""; + const std::string INVOKE_START = "<" + DSML + "invoke"; + const std::string INVOKE_END = ""; + const std::string PARAM_START = "<" + DSML + "parameter"; + const std::string PARAM_END = ""; + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto end = p.end(); + + auto reasoning = p.eps(); + if (extract_reasoning && inputs.enable_thinking) { + reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END); + } else if (extract_reasoning) { + // Thinking disabled but reasoning extraction requested: the generation prompt + // contains an empty pair that must still be consumed. + reasoning = p.optional(p.literal(THINK_START) + p.until(THINK_END) + p.literal(THINK_END)); + } + + if (has_response_format) { + auto response_format = p.rule("response-format", + p.literal("```json") + p.space() + + p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)) + + p.space() + p.literal("```")); + return generation_prompt + reasoning + response_format + end; + } + + if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return generation_prompt + reasoning + p.content(p.rest()) + end; + } + + auto tool_choice = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto params = function.contains("parameters") ? function.at("parameters") : json::object(); + const auto & props = params.contains("properties") ? params.at("properties") : json::object(); + + std::set required; + if (params.contains("required")) { + params.at("required").get_to(required); + } + + auto schema_info = common_schema_info(); + schema_info.resolve_refs(params); + + std::vector required_parsers; + std::vector optional_parsers; + for (const auto & [param_name, param_schema] : props.items()) { + bool is_required = required.find(param_name) != required.end(); + bool is_string = schema_info.resolves_to_string(param_schema); + + auto arg = p.tool_arg( + p.tool_arg_open( + p.literal(PARAM_START + " name=\"") + + p.tool_arg_name(p.literal(param_name)) + + p.literal("\" string=\"" + std::string(is_string ? "true" : "false") + "\">")) + + (is_string + ? p.tool_arg_string_value(p.until(PARAM_END)) + : p.tool_arg_json_value(p.schema(p.json(), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema, false))) + + p.tool_arg_close(p.literal(PARAM_END))); + + auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg); + if (is_required) { + required_parsers.push_back(named_arg); + } else { + optional_parsers.push_back(named_arg); + } + } + + common_peg_parser args_seq = p.eps(); + for (size_t i = 0; i < required_parsers.size(); i++) { + if (i > 0) { + args_seq = args_seq + p.space(); + } + args_seq = args_seq + required_parsers[i]; + } + + if (!optional_parsers.empty()) { + common_peg_parser any_opt = p.choice(); + for (const auto & opt : optional_parsers) { + any_opt |= opt; + } + args_seq = args_seq + p.repeat(p.space() + any_opt, 0, -1); + } + + common_peg_parser invoke_body = args_seq; + auto func_parser = p.tool( + p.tool_open(p.literal(INVOKE_START + " name=\"") + + p.tool_name(p.literal(name)) + p.literal("\">\n")) + + invoke_body + p.space() + + p.tool_close(p.literal(INVOKE_END))); + + tool_choice |= p.rule("tool-" + name, func_parser); + }); + + auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_calls = p.eps(); + if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", + p.literal(FC_START) + p.space() + tool_choice + + p.zero_or_more(p.space() + tool_choice) + p.space() + p.literal(FC_END)); + } else { + tool_calls = p.trigger_rule("tool-call", + p.literal(FC_START) + p.space() + tool_choice + p.space() + p.literal(FC_END)); + } + + if (!require_tools) { + tool_calls = p.optional(tool_calls); + } + + auto content_before_tools = p.content(p.until(FC_START)); + return generation_prompt + reasoning + content_before_tools + tool_calls + end; + }); + + data.parser = parser.save(); + + if (include_grammar) { + data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED)); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + auto schema = function.contains("parameters") ? function.at("parameters") : json::object(); + builder.resolve_refs(schema); + }); + if (has_response_format) { + auto schema = inputs.json_schema; + builder.resolve_refs(schema); + } + parser.build_grammar(builder, data.grammar_lazy); + }); + + data.grammar_triggers = { + { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, FC_START }, + }; + } + + return data; +} + namespace workaround { static void map_developer_role_to_system(json & messages) { @@ -1942,6 +2122,15 @@ std::optional common_chat_try_specialized_template( return common_chat_params_init_gigachat_v3(tmpl, params); } + // DeepSeek V3.2 format detection: template defines dsml_token and uses it for tool calls. + // The template source contains the token as a variable assignment, not as a literal in markup. + if (src.find("dsml_token") != std::string::npos && + src.find("function_calls") != std::string::npos && + src.find("DSML") != std::string::npos) { + LOG_DBG("Using specialized template: DeepSeek V3.2\n"); + return common_chat_params_init_deepseek_v3_2(tmpl, params); + } + // Gemma4 format detection if (src.find("'<|tool_call>call:'") != std::string::npos) { if (src.find("{#- OpenAI Chat Completions:") == std::string::npos) { diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index 59fa4c5c5..e37c1ce80 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -890,6 +890,10 @@ struct parser_executor { } return result; } + + common_peg_parse_result operator()(const common_peg_gbnf_parser & p) { + return arena.parse(p.child, ctx, start_pos); + } }; common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const { @@ -957,7 +961,8 @@ void common_peg_arena::resolve_refs() { std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { p.child = resolve_ref(p.child); } else if constexpr (std::is_same_v) { p.child = resolve_ref(p.child); @@ -1036,6 +1041,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id return "Not(" + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Atomic(" + dump_impl(p.child, visited) + ")"; + } else if constexpr (std::is_same_v) { + return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Any"; } else if constexpr (std::is_same_v) { @@ -1565,6 +1572,7 @@ static std::unordered_set collect_reachable_rules( std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { visit(p.child); } else if constexpr (std::is_same_v) { @@ -1651,10 +1659,13 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo } else if constexpr (std::is_same_v) { std::string s; for (const auto & child : p.children) { + auto child_gbnf = to_gbnf(child); + if (child_gbnf.empty()) { + continue; + } if (!s.empty()) { s += " "; } - auto child_gbnf = to_gbnf(child); const auto & child_parser = effective_parser(child); if (std::holds_alternative(child_parser) || std::holds_alternative(child_parser)) { @@ -1754,6 +1765,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo return to_gbnf(p.child); } else if constexpr (std::is_same_v) { return to_gbnf(p.child); + } else if constexpr (std::is_same_v) { + return p.grammar; } else { static_assert(is_always_false_v); } @@ -1888,6 +1901,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant & {"child", p.child}, {"tag", p.tag} }; + } else if constexpr (std::is_same_v) { + return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}}; } }, variant); } @@ -2050,6 +2065,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json }; } + if (type == "gbnf") { + if (!j.contains("child") || !j.contains("grammar")) { + throw std::runtime_error("gbnf parser missing required fields"); + } + return common_peg_gbnf_parser{ + j["child"].get(), + j["grammar"].get(), + }; + } + throw std::runtime_error("Unknown parser type: " + type); } diff --git a/common/peg-parser.h b/common/peg-parser.h index f242fc421..b6bb05214 100644 --- a/common/peg-parser.h +++ b/common/peg-parser.h @@ -270,6 +270,11 @@ struct common_peg_tag_parser { std::string tag; }; +struct common_peg_gbnf_parser { + common_peg_parser_id child; + std::string grammar; +}; + // Variant holding all parser types using common_peg_parser_variant = std::variant< common_peg_epsilon_parser, @@ -290,7 +295,8 @@ using common_peg_parser_variant = std::variant< common_peg_rule_parser, common_peg_ref_parser, common_peg_atomic_parser, - common_peg_tag_parser + common_peg_tag_parser, + common_peg_gbnf_parser >; class common_peg_arena { @@ -504,6 +510,10 @@ class common_peg_parser_builder { // Unlike rules, you can tag multiple nodes with the same tag. common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); } + // Wraps a child parser but emits a custom GBNF grammar string instead of + // the child's grammar. Parsing delegates entirely to the child. + common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); } + void set_root(const common_peg_parser & p); common_peg_arena build(); diff --git a/common/sampling.cpp b/common/sampling.cpp index 2f60be194..526f036ff 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -287,8 +287,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st } } - // reasoning budget sampler - if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) { + // reasoning budget sampler (skip when budget is unlimited unless a lazy grammar is active, which needs rbudget for thinking-block suppression) + if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty() && (params.grammar_lazy || params.reasoning_budget_tokens >= 0)) { rbudget = common_reasoning_budget_init( vocab, params.reasoning_budget_start, diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index e09db59cf..64d811faf 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -783,6 +783,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b)); const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4)); +#if defined(__ARM_FEATURE_DOTPROD) const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs); const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16); const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b)); @@ -794,15 +795,40 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b)); const int32x4_t p0 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); + vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), + vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); const int32x4_t p1 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); + vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), + vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); - const int32x4_t sums = vpaddq_s32(p0, p1); + const int32x4_t sumi = vpaddq_s32(p0, p1); +#else + const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0); + const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0); + const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0); + const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0); + const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1); + const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1); + const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1); + const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1); + + const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs); + const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8); + const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16); + const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24); + const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs); + const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8); + const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16); + const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24); + + const int32x4_t sumi = (int32x4_t){ + vaddvq_s32(ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi)), + }; +#endif - // Decode 4 UE4M3 scales to f32 and multiply with q8 scales const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d); const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d); const float32x4_t nvsc = { @@ -813,7 +839,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo }; const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1}); - acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales); + acc = vfmaq_f32(acc, vcvtq_f32_s32(sumi), scales); } sumf = vaddvq_f32(acc); #else diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 88a9c9ec0..5d1ca5ffc 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -306,6 +306,7 @@ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { #if !defined(__ARM_FEATURE_DOTPROD) +// NOTE: this fallback produces the same total sum as native vdotq_s32 but with different per-lane grouping — do not use when individual lane values matter. inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); @@ -319,6 +320,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif // !defined(__ARM_FEATURE_DOTPROD) +static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo, + const int8x8_t q4_hi, const int8x8_t q8_hi) { + const int16x8_t p_lo = vmull_s8(q4_lo, q8_lo); + const int16x8_t p_hi = vmull_s8(q4_hi, q8_hi); + const int32x4_t sum_lo = vpaddlq_s16(p_lo); + const int32x4_t sum_hi = vpaddlq_s16(p_hi); + return vaddq_s32(sum_lo, sum_hi); +} + #endif // defined(__ARM_NEON) #ifdef __wasm_simd128__ diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 391e33ad9..80bcf3830 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3095,6 +3095,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec case GGML_TYPE_MXFP4: lut_size = 4*16; break; + case GGML_TYPE_NVFP4: + // Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4). + lut_size = 4*16 + 128u * (uint32_t)sizeof(float); + break; default: break; } @@ -3574,6 +3578,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) GGML_ASSERT(device->subgroup_ballot); @@ -3604,6 +3609,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) #undef CREATE_MM #undef CREATE_MM2 } else @@ -3667,6 +3673,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3690,6 +3697,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } GGML_ASSERT(device->subgroup_ballot); @@ -3724,6 +3732,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #undef CREATE_MM2 #undef CREATE_MM } else @@ -3789,6 +3798,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3835,6 +3845,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3880,6 +3891,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3955,6 +3967,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3999,6 +4012,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); } else { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4026,6 +4040,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } } // reusing CREATE_MM from the fp32 path @@ -4124,6 +4139,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); @@ -4149,6 +4165,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4200,6 +4217,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4255,6 +4273,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -4281,6 +4300,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -4307,6 +4327,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); @@ -6119,6 +6140,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6191,6 +6213,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6257,6 +6280,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6348,6 +6372,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6417,6 +6442,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -15411,6 +15437,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return false; @@ -15526,6 +15553,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_I32: return true; default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 06df50952..6a6921474 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -4,7 +4,7 @@ #include "generic_unary_head.glsl" #include "dequant_funcs.glsl" -#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) // 16 invocations needed for init_iq_shmem layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; #else diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index ede1275cf..88d07d2df 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -450,6 +450,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint sub = iqs >> 4; + const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]); + const uint j = iqs & 7; + const uint shift = (iqs & 8) >> 1; // 0 or 4 + const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]); + const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]); + const uint qs0 = (vui0 >> shift) & 0xF; + const uint qs1 = (vui1 >> shift) & 0xF; + return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5; +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 v1 = dequantize(ib, iqs + 2u, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); @@ -484,6 +503,12 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1.0, 0.0); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 03035f281..c582aba87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -697,6 +697,24 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords } #endif +#if defined(DATA_A_NVFP4) +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 { + block_nvfp4 block; +}; + +float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint sub = (idx & 0x30) >> 4; + const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7); + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + uint qs = uint(bl.block.qs[iqs]); + qs = (qs >> shift) & 0xF; + return float16_t(kvalues_mxfp4[qs] * d * 0.5); +} +#endif + #if defined(DATA_A_Q1_0) #define dequantFuncA dequantFuncQ1_0 #elif defined(DATA_A_Q4_0) @@ -743,6 +761,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_NL #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#elif defined(DATA_A_NVFP4) +#define dequantFuncA dequantFuncNVFP4 #elif defined(DATA_A_F32) #define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp new file mode 100644 index 000000000..689089160 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint sub = tid / 16; + const uint ir = tid % 16; + const uint ib = 16 * i + ir; + if (ib >= p.nel / 64) { + return; + } + + const uint q_idx = 8 * sub; + const uint b_idx = 1024 * i + 64 * ir + 16 * sub; + + const float d = ue4m3_to_fp32(data_a[ib].d[sub]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); + data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 219bd6080..6e4a29d2f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -501,6 +501,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin kvalues_mxfp4[vui2 & 0xF] * d); buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, kvalues_mxfp4[vui2 >> 4] * d); +#elif defined(DATA_A_NVFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + // lo and hi nibbles are 8 elements apart, which doesn't quite line up with + // how the thread mapping and buf_idx calculation works for other types. + const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2; + + const uint ib = idx / 16u; + const uint sub = (idx & 0xC) >> 2; + const uint iqs = (idx & 0xF) * 2; + const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5; + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 1fb592fb8..4bcd97756 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1713,6 +1713,22 @@ struct block_mxfp4 #define A_TYPE block_mxfp4 #endif +#define QUANT_K_NVFP4 64 +#define QUANT_R_NVFP4 1 + +struct block_nvfp4 +{ + uint8_t d[QUANT_K_NVFP4 / 16]; + uint8_t qs[QUANT_K_NVFP4 / 2]; +}; + +#if defined(DATA_A_NVFP4) +#define QUANT_K QUANT_K_NVFP4 +#define QUANT_R QUANT_R_NVFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_nvfp4 +#endif + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), @@ -1732,7 +1748,7 @@ void init_iq_shmem(uvec3 wgsize) } #endif -#if defined(DATA_A_MXFP4) +#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) const int8_t kvalues_mxfp4_const[16] = { int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), @@ -1740,6 +1756,24 @@ const int8_t kvalues_mxfp4_const[16] = { shared int8_t kvalues_mxfp4[16]; +#if defined(DATA_A_NVFP4) +// UE4M3 scale in NVFP4 blocks use only 7 bits; sign (bit 7) is always zero. +shared float ue4m3_fp32_lut[128]; + +float ue4m3_to_fp32_build(uint u) { + if (u == 0u || u == 127u) { + return 0.0; + } + const uint exp = (u >> 3) & 15u; + const uint man = u & 7u; + if (exp == 0u) { + return float(man) * (1.0 / 512.0); + } + const uint bits = (exp + 120u) << 23 | (man << 20); + return uintBitsToFloat(bits); +} +#endif + #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { @@ -1747,6 +1781,11 @@ void init_iq_shmem(uvec3 wgsize) for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; } +#if defined(DATA_A_NVFP4) + for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) { + ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i); + } +#endif barrier(); } #endif @@ -1783,6 +1822,12 @@ float e8m0_to_fp32(uint8_t x) { return uintBitsToFloat(bits); } +#if defined(DATA_A_NVFP4) +float ue4m3_to_fp32(uint8_t x) { + return ue4m3_fp32_lut[uint(x)]; +} +#endif + #if BDA #extension GL_EXT_buffer_reference : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b1db59fee..8492dd534 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -80,6 +80,7 @@ const std::vector type_names = { "iq4_xs", "iq4_nl", "mxfp4", + "nvfp4", "bf16", }; @@ -573,7 +574,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_quant = "2"; if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4")) load_vec_quant = "4"; if (tname == "bf16") { diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index eb0198616..5da48d61b 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -98,6 +98,7 @@ add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M" add_test_audio "ggml-org/LFM2-Audio-1.5B-GGUF:Q8_0" add_test_audio "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja +add_test_audio "ggml-org/Qwen3-ASR-0.6B-GGUF:Q8_0" # to test the big models, run: ./tests.sh big if [ "$RUN_BIG_TESTS" = true ]; then diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index ed5e306fc..e3f243902 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1433,6 +1433,60 @@ json convert_responses_to_chatcmpl(const json & response_body) { return chatcmpl_body; } +json convert_transcriptions_to_chatcmpl( + const json & inp_body, + const std::map & in_files, + std::vector & out_files) { + // TODO @ngxson : this function may need to be improved in the future + // handle input files + out_files.clear(); + auto it = in_files.find("file"); + if (it != in_files.end()) { + out_files.push_back(it->second); + } else { + throw std::invalid_argument("No input file found for transcription"); + } + + // handle input data + std::string prompt = json_value(inp_body, "prompt", std::string()); + std::string language = json_value(inp_body, "language", std::string()); + std::string response_format = json_value(inp_body, "response_format", std::string("json")); + if (response_format != "json") { + throw std::invalid_argument("Only 'json' response_format is supported for transcription"); + } + if (prompt.empty()) { + prompt = "Transcribe audio to text"; + } + if (!language.empty()) { + prompt += string_format(" (language: %s)", language.c_str()); + } + prompt += mtmd_default_marker(); + + json chatcmpl_body = inp_body; // copy all fields + chatcmpl_body["messages"] = json::array({ + { + {"role", "user"}, + {"content", prompt}, + }, + }); + + // because input from form-data, everything is string, we need to correct the types here + std::string stream = json_value(inp_body, "stream", std::string("false")); + chatcmpl_body["stream"] = stream == "true"; + + if (inp_body.contains("max_tokens")) { + std::string inp = inp_body["max_tokens"].get(); + chatcmpl_body["max_tokens"] = std::stoul(inp); + } + + if (inp_body.contains("temperature")) { + std::string inp = inp_body["temperature"].get(); + chatcmpl_body["temperature"] = std::stof(inp); + } + + return chatcmpl_body; +} + json convert_anthropic_to_oai(const json & body) { json oai_body; diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 213ae52bb..440ebc597 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -305,6 +305,12 @@ json oaicompat_chat_params_parse( // convert OpenAI Responses API format to OpenAI Chat Completions API format json convert_responses_to_chatcmpl(const json & body); +// convert OpenAI transcriptions API format to OpenAI Chat Completions API format +json convert_transcriptions_to_chatcmpl( + const json & body, + const std::map & in_files, + std::vector & out_files); + // convert Anthropic Messages API format to OpenAI Chat Completions API format json convert_anthropic_to_oai(const json & body); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index b31981c56..e134b3cfb 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3732,6 +3732,33 @@ void server_routes::init_routes() { TASK_RESPONSE_TYPE_OAI_RESP); }; + this->post_transcriptions_oai = [this](const server_http_req & req) { + auto res = create_response(); + + if (!meta->has_mtmd || !meta->chat_params.allow_audio) { + res->error(format_error_response("The current model does not support audio input.", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + std::vector files; + json body = convert_transcriptions_to_chatcmpl( + json::parse(req.body), + req.files, + files); + SRV_DBG("%s\n", "Request converted: OpenAI Transcriptions -> OpenAI Chat Completions"); + SRV_DBG("converted request: %s\n", body.dump().c_str()); + json body_parsed = oaicompat_chat_params_parse( + body, + meta->chat_params, + files); + return handle_completions_impl( + req, + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + TASK_RESPONSE_TYPE_OAI_ASR); + }; + this->post_anthropic_messages = [this](const server_http_req & req) { auto res = create_response(); std::vector files; diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 6ea9afc0a..6856043fa 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -111,6 +111,7 @@ struct server_routes { server_http_context::handler_t post_completions_oai; server_http_context::handler_t post_chat_completions; server_http_context::handler_t post_responses_oai; + server_http_context::handler_t post_transcriptions_oai; server_http_context::handler_t post_anthropic_messages; server_http_context::handler_t post_anthropic_count_tokens; server_http_context::handler_t post_apply_template; diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 37e7cbe9c..83f656f5c 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -428,6 +428,7 @@ void server_http_context::get(const std::string & path, const server_http_contex req.path, build_query_string(req), req.body, + {}, req.is_connection_closed }); server_http_res_ptr response = handler(*request); @@ -437,12 +438,39 @@ void server_http_context::get(const std::string & path, const server_http_contex void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const { pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { + std::string body = req.body; + std::map files; + + if (req.is_multipart_form_data()) { + // translate text fields to a JSON object and use it as the body + json form_json = json::object(); + for (const auto & [key, field] : req.form.fields) { + if (form_json.contains(key)) { + // if the key already exists, convert it to an array + if (!form_json[key].is_array()) { + json existing_value = form_json[key]; + form_json[key] = json::array({existing_value}); + } + form_json[key].push_back(field.content); + } else { + form_json[key] = field.content; + } + } + body = form_json.dump(); + + // populate files from multipart form + for (const auto & [key, file] : req.form.files) { + files[key] = raw_buffer(file.content.begin(), file.content.end()); + } + } + server_http_req_ptr request = std::make_unique(server_http_req{ get_params(req), get_headers(req), req.path, build_query_string(req), - req.body, + body, + std::move(files), req.is_connection_closed }); server_http_res_ptr response = handler(*request); diff --git a/tools/server/server-http.h b/tools/server/server-http.h index f8a174c44..68ae2170c 100644 --- a/tools/server/server-http.h +++ b/tools/server/server-http.h @@ -5,6 +5,8 @@ #include #include #include +#include +#include struct common_params; @@ -32,6 +34,7 @@ struct server_http_res { // unique pointer, used by set_chunked_content_provider // httplib requires the stream provider to be stored in heap using server_http_res_ptr = std::unique_ptr; +using raw_buffer = std::vector; struct server_http_req { std::map params; // path_params + query_params @@ -39,6 +42,7 @@ struct server_http_req { std::string path; std::string query_string; // query parameters string (e.g. "action=save") std::string body; + std::map files; // used for file uploads (form data) const std::function & should_stop; std::string get_param(const std::string & key, const std::string & def = "") const { diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 6a06171d7..0312f098a 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -725,6 +725,8 @@ json server_task_result_cmpl_final::to_json() { return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); case TASK_RESPONSE_TYPE_OAI_RESP: return stream ? to_json_oaicompat_resp_stream() : to_json_oaicompat_resp(); + case TASK_RESPONSE_TYPE_OAI_ASR: + return to_json_oaicompat_asr(); case TASK_RESPONSE_TYPE_ANTHROPIC: return stream ? to_json_anthropic_stream() : to_json_anthropic(); default: @@ -1102,6 +1104,21 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() { return server_sent_events; } +json server_task_result_cmpl_final::to_json_oaicompat_asr() { + json event = json { + {"type", "transcript.text.done"}, + {"text", content}, + {"usage", json { + {"type", "tokens"}, + {"input_tokens", n_prompt_tokens}, + {"output_tokens", n_decoded}, + {"total_tokens", n_decoded + n_prompt_tokens}, + {"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, + }}, + }; + return event; +} + json server_task_result_cmpl_final::to_json_anthropic() { std::string stop_reason = "max_tokens"; if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { @@ -1400,6 +1417,8 @@ json server_task_result_cmpl_partial::to_json() { return to_json_oaicompat_chat(); case TASK_RESPONSE_TYPE_OAI_RESP: return to_json_oaicompat_resp(); + case TASK_RESPONSE_TYPE_OAI_ASR: + return to_json_oaicompat_asr(); case TASK_RESPONSE_TYPE_ANTHROPIC: return to_json_anthropic(); default: @@ -1650,6 +1669,14 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() { return events; } +json server_task_result_cmpl_partial::to_json_oaicompat_asr() { + json event = json { + {"type", "transcript.text.delta"}, + {"delta", content}, + }; + return event; +} + json server_task_result_cmpl_partial::to_json_anthropic() { json events = json::array(); bool first = (n_decoded == 1); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 243e47a8e..95f39207b 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -34,6 +34,7 @@ enum task_response_type { TASK_RESPONSE_TYPE_OAI_CHAT, TASK_RESPONSE_TYPE_OAI_CMPL, TASK_RESPONSE_TYPE_OAI_RESP, + TASK_RESPONSE_TYPE_OAI_ASR, // transcriptions API TASK_RESPONSE_TYPE_OAI_EMBD, TASK_RESPONSE_TYPE_ANTHROPIC, }; @@ -401,6 +402,8 @@ struct server_task_result_cmpl_final : server_task_result { json to_json_oaicompat_resp_stream(); + json to_json_oaicompat_asr(); + json to_json_anthropic(); json to_json_anthropic_stream(); @@ -457,6 +460,8 @@ struct server_task_result_cmpl_partial : server_task_result { json to_json_oaicompat_resp(); + json to_json_oaicompat_asr(); + json to_json_anthropic(); }; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index b9e320d9c..fe640b978 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -145,6 +145,7 @@ int main(int argc, char ** argv) { routes.post_completions_oai = models_routes->proxy_post; routes.post_chat_completions = models_routes->proxy_post; routes.post_responses_oai = models_routes->proxy_post; + routes.post_transcriptions_oai = models_routes->proxy_post; routes.post_anthropic_messages = models_routes->proxy_post; routes.post_anthropic_count_tokens = models_routes->proxy_post; routes.post_infill = models_routes->proxy_post; @@ -160,48 +161,51 @@ int main(int argc, char ** argv) { routes.post_slots = models_routes->proxy_post; // custom routes for router - routes.get_props = models_routes->get_router_props; - routes.get_models = models_routes->get_router_models; - ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load)); - ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload)); + routes.get_props = models_routes->get_router_props; + routes.get_models = models_routes->get_router_models; + + ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load)); + ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload)); } - ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) - ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) - ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); - ctx_http.get ("/props", ex_wrapper(routes.get_props)); - ctx_http.post("/props", ex_wrapper(routes.post_props)); - ctx_http.post("/api/show", ex_wrapper(routes.get_api_show)); - ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) - ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) - ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check) - ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy - ctx_http.post("/completions", ex_wrapper(routes.post_completions)); - ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai)); - ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); - ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); - ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint - ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai)); - ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai)); - ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API + ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) + ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) + ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); + ctx_http.get ("/props", ex_wrapper(routes.get_props)); + ctx_http.post("/props", ex_wrapper(routes.post_props)); + ctx_http.post("/api/show", ex_wrapper(routes.get_api_show)); + ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) + ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) + ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check) + ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy + ctx_http.post("/completions", ex_wrapper(routes.post_completions)); + ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai)); + ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); + ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); + ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint + ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai)); + ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai)); + ctx_http.post("/v1/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai)); + ctx_http.post("/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai)); + ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting - ctx_http.post("/infill", ex_wrapper(routes.post_infill)); - ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy - ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); - ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai)); - ctx_http.post("/rerank", ex_wrapper(routes.post_rerank)); - ctx_http.post("/reranking", ex_wrapper(routes.post_rerank)); - ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank)); - ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank)); - ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize)); - ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize)); - ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template)); + ctx_http.post("/infill", ex_wrapper(routes.post_infill)); + ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy + ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); + ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai)); + ctx_http.post("/rerank", ex_wrapper(routes.post_rerank)); + ctx_http.post("/reranking", ex_wrapper(routes.post_rerank)); + ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank)); + ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank)); + ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize)); + ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize)); + ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template)); // LoRA adapters hotswap - ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters)); - ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters)); + ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters)); + ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters)); // Save & load slots - ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); - ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); + ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); + ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); // CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP) if (params.webui_mcp_proxy) { SRV_WRN("%s", "-----------------\n"); diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index f140d6d3a..78dc48332 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -39,7 +39,7 @@ if (LLAMA_BUILD_BORINGSSL) set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)") set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository") - set(BORINGSSL_VERSION "0.20260327.0" CACHE STRING "BoringSSL version") + set(BORINGSSL_VERSION "0.20260413.0" CACHE STRING "BoringSSL version") message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")