diff --git a/common/chat.cpp b/common/chat.cpp index 7c071560f..159d625de 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -544,6 +544,26 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp return tmpls->has_explicit_template; } +// LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list +// and <|tool_call_start|>[...]<|tool_call_end|> around each tool call +static bool is_lfm2_template(const std::string & src) { + return src.find("<|tool_list_start|>") != std::string::npos && + src.find("<|tool_list_end|>") != std::string::npos; +} + +common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates) { + common_chat_prompt_preset asr_preset; + asr_preset.system = ""; + asr_preset.user = "Transcribe audio to text"; + + if (chat_templates && chat_templates->template_default && is_lfm2_template(chat_templates->template_default->source())) { + asr_preset.system = "Perform ASR."; + asr_preset.user = ""; + } + + return asr_preset; +} + std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) { if (!variant.empty()) { if (variant == "tool_use") { @@ -2053,10 +2073,7 @@ std::optional common_chat_try_specialized_template( return common_chat_params_init_kimi_k2(tmpl, params); } - // LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list - // and <|tool_call_start|>[...]<|tool_call_end|> around each tool call - if (src.find("<|tool_list_start|>") != std::string::npos && - src.find("<|tool_list_end|>") != std::string::npos) { + if (is_lfm2_template(src)) { LOG_DBG("Using specialized template: LFM2\n"); return common_chat_params_init_lfm2(tmpl, params); } @@ -2365,4 +2382,3 @@ std::map common_chat_templates_get_caps(const common_chat_tem GGML_ASSERT(chat_templates->template_default != nullptr); return chat_templates->template_default->caps.to_map(); } - diff --git a/common/chat.h b/common/chat.h index 9122f2967..01a47b383 100644 --- a/common/chat.h +++ b/common/chat.h @@ -274,3 +274,11 @@ std::optional common_chat_try_specialized_template( const common_chat_template & tmpl, const std::string & src, autoparser::generation_params & params); + +// specialized per-task preset +struct common_chat_prompt_preset { + std::string system; + std::string user; +}; + +common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates); diff --git a/tools/server/server-chat.cpp b/tools/server/server-chat.cpp index 4fe81553c..ef586d1e1 100644 --- a/tools/server/server-chat.cpp +++ b/tools/server/server-chat.cpp @@ -535,6 +535,7 @@ json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { json convert_transcriptions_to_chatcmpl( const json & inp_body, + const common_chat_templates * tmpls, const std::map & in_files, std::vector & out_files) { // TODO @ngxson : this function may need to be improved in the future @@ -548,27 +549,29 @@ json convert_transcriptions_to_chatcmpl( } // handle input data - std::string prompt = json_value(inp_body, "prompt", std::string()); - std::string language = json_value(inp_body, "language", std::string()); + std::string prompt = json_value(inp_body, "prompt", std::string()); + std::string language = json_value(inp_body, "language", std::string()); std::string response_format = json_value(inp_body, "response_format", std::string("json")); if (response_format != "json") { throw std::invalid_argument("Only 'json' response_format is supported for transcription"); } + const common_chat_prompt_preset preset = common_chat_get_asr_prompt(tmpls); if (prompt.empty()) { - prompt = "Transcribe audio to text"; + prompt = preset.user; } if (!language.empty()) { prompt += string_format(" (language: %s)", language.c_str()); } prompt += get_media_marker(); + json messages = json::array(); + if (!preset.system.empty()) { + messages.push_back({{"role", "system"}, {"content", preset.system}}); + } + messages.push_back({{"role", "user"}, {"content", prompt}}); + json chatcmpl_body = inp_body; // copy all fields - chatcmpl_body["messages"] = json::array({ - { - {"role", "user"}, - {"content", prompt}, - }, - }); + chatcmpl_body["messages"] = messages; // because input from form-data, everything is string, we need to correct the types here std::string stream = json_value(inp_body, "stream", std::string("false")); diff --git a/tools/server/server-chat.h b/tools/server/server-chat.h index ecb8907c4..5c5b792cf 100644 --- a/tools/server/server-chat.h +++ b/tools/server/server-chat.h @@ -18,6 +18,7 @@ json server_chat_convert_anthropic_to_oai(const json & body); // convert OpenAI transcriptions API format to OpenAI Chat Completions API format json convert_transcriptions_to_chatcmpl( const json & body, + const common_chat_templates * tmpls, const std::map & in_files, std::vector & out_files); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index b8c05cd80..67a92755b 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3807,6 +3807,7 @@ void server_routes::init_routes() { std::vector files; json body = convert_transcriptions_to_chatcmpl( json::parse(req.body), + meta->chat_params.tmpls.get(), req.files, files); SRV_DBG("%s\n", "Request converted: OpenAI Transcriptions -> OpenAI Chat Completions");