mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-22 19:47:49 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .github/workflows/server-self-hosted.yml # CMakeLists.txt # CODEOWNERS # ci/run.sh # cmake/llama-config.cmake.in # common/chat.cpp # examples/sycl/start-svr.sh # examples/sycl/test.sh # examples/sycl/win-start-svr.bat # examples/sycl/win-test.bat # ggml/src/ggml-sycl/ggml-sycl.cpp # ggml/src/ggml-sycl/vecdotq.hpp # ggml/src/ggml-vulkan/CMakeLists.txt # scripts/wc2wt.sh # tests/test-backend-ops.cpp # tests/test-chat.cpp
This commit is contained in:
commit
fecf2dc3fa
79 changed files with 946 additions and 305 deletions
|
|
@ -2810,7 +2810,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params, int value) {
|
||||
params.embd_normalize = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG}));
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}));
|
||||
add_opt(common_arg(
|
||||
{"--embd-output-format"}, "FORMAT",
|
||||
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
|
||||
|
|
|
|||
|
|
@ -43,11 +43,33 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
|||
const autoparser & autoparser) {
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
std::string parser_generation_prompt = data.generation_prompt;
|
||||
|
||||
if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !inputs.continue_msg.empty()) {
|
||||
// Build up generation prompt manually
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
if (!autoparser.reasoning.start.empty()) {
|
||||
data.generation_prompt = data.generation_prompt.substr(0, data.generation_prompt.find(autoparser.reasoning.start));
|
||||
data.generation_prompt += autoparser.reasoning.start + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += autoparser.reasoning.end;
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = autoparser.build_parser(inputs, parser_generation_prompt);
|
||||
data.parser = parser.save();
|
||||
|
||||
// Build grammar if tools are present
|
||||
|
|
@ -87,7 +109,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
|||
return data;
|
||||
}
|
||||
|
||||
common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
|
||||
common_peg_arena autoparser::build_parser(const generation_params & inputs, const std::string & generation_prompt) const {
|
||||
if (!analysis_complete) {
|
||||
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
|
||||
}
|
||||
|
|
@ -121,7 +143,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
|
|||
} else {
|
||||
parser = content.build_parser(ctx);
|
||||
}
|
||||
return pure_content ? p.prefix(inputs.generation_prompt, reasoning.start) + parser : p.prefix(inputs.generation_prompt, reasoning.start) << parser;
|
||||
return pure_content ? p.prefix(generation_prompt, reasoning.start) + parser : p.prefix(generation_prompt, reasoning.start) << parser;
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -60,16 +60,21 @@ struct generation_params {
|
|||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||
bool stream = true;
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = false;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::string generation_prompt;
|
||||
bool add_generation_prompt = false;
|
||||
common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE;
|
||||
common_chat_msg continue_msg;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
json extra_context;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
bool is_inference = true;
|
||||
bool add_inference = false;
|
||||
bool mark_input = true; // whether to mark input strings in the jinja context
|
||||
|
||||
bool has_continuation() const {
|
||||
return continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !continue_msg.empty();
|
||||
}
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -386,7 +391,7 @@ struct autoparser {
|
|||
void analyze_template(const common_chat_template & tmpl);
|
||||
|
||||
// Build the PEG parser for this template
|
||||
common_peg_arena build_parser(const generation_params & inputs) const;
|
||||
common_peg_arena build_parser(const generation_params & inputs, const std::string & generation_prompt) const;
|
||||
|
||||
private:
|
||||
// Collect tokens from entire analysis to preserve
|
||||
|
|
|
|||
|
|
@ -358,35 +358,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
|||
if (is_potential_container) {
|
||||
value_content = normalize_container_value(value_content);
|
||||
}
|
||||
|
||||
// Try to parse as JSON value (number, bool, null, object, array)
|
||||
try {
|
||||
ordered_json parsed = ordered_json::parse(value_content);
|
||||
if (parsed.is_string()) {
|
||||
// Don't add closing quote yet (added by arg_close) for monotonic streaming
|
||||
std::string escaped = parsed.dump();
|
||||
if (!escaped.empty() && escaped.back() == '"') {
|
||||
escaped.pop_back();
|
||||
}
|
||||
value_to_add = escaped;
|
||||
closing_quote_pending = true;
|
||||
} else {
|
||||
// Non-string values: use raw content to preserve whitespace for monotonicity
|
||||
value_to_add = value_content;
|
||||
}
|
||||
} catch (...) {
|
||||
if (node.is_partial && is_potential_container) {
|
||||
// Partial container: pass through the already-normalized content
|
||||
value_to_add = value_content;
|
||||
} else {
|
||||
// Not valid JSON - treat as string value
|
||||
if (!closing_quote_pending) {
|
||||
value_to_add = "\"";
|
||||
closing_quote_pending = true;
|
||||
}
|
||||
value_to_add += escape_json_string_inner(value_content);
|
||||
}
|
||||
}
|
||||
value_to_add += value_content;
|
||||
}
|
||||
|
||||
args_target() += value_to_add;
|
||||
|
|
@ -813,7 +785,7 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s
|
|||
if (delimiter.empty()) {
|
||||
return literal(s);
|
||||
}
|
||||
return literal(s.substr(0, s.rfind(delimiter)));
|
||||
return literal(s.substr(0, s.find(delimiter)));
|
||||
}
|
||||
|
||||
common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) {
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class common_chat_peg_builder : public common_peg_parser_builder {
|
|||
|
||||
// Use for schema-declared string types - won't be treated as potential JSON container
|
||||
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
|
||||
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); }
|
||||
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_VALUE, p); }
|
||||
|
||||
|
||||
// Return a parser that parses the prefix of a string, up to a given delimiter.
|
||||
|
|
|
|||
256
common/chat.cpp
256
common/chat.cpp
|
|
@ -81,6 +81,26 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) {
|
|||
return !msg.content.empty() || !msg.tool_calls.empty();
|
||||
}
|
||||
|
||||
std::string common_chat_msg::render_content(const std::string & delimiter) const {
|
||||
if (!content.empty() && !content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
if (!content.empty()) {
|
||||
return content;
|
||||
}
|
||||
|
||||
std::string text;
|
||||
for (const auto & part : content_parts) {
|
||||
if (part.type == "text") {
|
||||
if (!text.empty()) {
|
||||
text += delimiter;
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
if (!content.empty() && !content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
|
|
@ -465,6 +485,23 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
|
|||
#include "common/unicode.h"
|
||||
#include "peg-parser.cpp"
|
||||
#include "chat-peg-parser.cpp"
|
||||
|
||||
// Translates a JSON value into a continuation mode.
//
//   true                  -> COMMON_CHAT_CONTINUATION_AUTO
//   "reasoning_content"   -> COMMON_CHAT_CONTINUATION_REASONING
//   "content"             -> COMMON_CHAT_CONTINUATION_CONTENT
//   anything else         -> COMMON_CHAT_CONTINUATION_NONE
//
// "Anything else" includes `false`, unrecognised strings and every non-boolean,
// non-string JSON type.
common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value) {
    if (value.is_boolean()) {
        return value.get<bool>() ? COMMON_CHAT_CONTINUATION_AUTO : COMMON_CHAT_CONTINUATION_NONE;
    }

    if (value.is_string()) {
        const auto mode = value.get<std::string>();
        if (mode == "reasoning_content") {
            return COMMON_CHAT_CONTINUATION_REASONING;
        }
        if (mode == "content") {
            return COMMON_CHAT_CONTINUATION_CONTENT;
        }
    }

    return COMMON_CHAT_CONTINUATION_NONE;
}
|
||||
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
|
|
@ -825,6 +862,36 @@ std::string common_chat_template_direct_apply(
|
|||
return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
static std::string common_chat_template_generation_prompt_impl(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt) {
|
||||
|
||||
auto adjusted_messages = messages_override ? *messages_override : inputs.messages;
|
||||
|
||||
autoparser::generation_params params = inputs;
|
||||
params.add_generation_prompt = false;
|
||||
params.continue_final_message = COMMON_CHAT_CONTINUATION_NONE;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context);
|
||||
|
||||
size_t prefix_len = 0;
|
||||
size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size());
|
||||
while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) {
|
||||
prefix_len++;
|
||||
}
|
||||
return gen_prompt.substr(prefix_len);
|
||||
}
|
||||
|
||||
std::string common_chat_template_generation_prompt(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
return common_chat_template_generation_prompt_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
|
@ -877,6 +944,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
data.thinking_start_tag = "[THINK]";
|
||||
data.thinking_end_tag = "[/THINK]";
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
"[THINK]",
|
||||
|
|
@ -885,8 +953,19 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
"[ARGS]",
|
||||
};
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = "[THINK]" + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += "[/THINK]" + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, "[THINK]");
|
||||
auto generation_prompt = p.eps();
|
||||
auto reasoning =
|
||||
extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps();
|
||||
|
||||
|
|
@ -977,6 +1056,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
}
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
|
||||
|
|
@ -986,6 +1066,18 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
"<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>",
|
||||
};
|
||||
|
||||
// Adjust prompt for continuation
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = "<|start|>assistant<|channel|>analysis<|message|>" + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += "<|end|><|start|>assistant<|channel|>final<|message|>" + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
|
@ -1094,12 +1186,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
|||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
|
||||
if (inputs.add_generation_prompt && string_ends_with(data.prompt, "<turn|>\n")) {
|
||||
// This may happen if the model generates content + tool_call, the
|
||||
// template does not add the model's next turn and confuses the model
|
||||
// from emitting its proper reasoning token sequence.
|
||||
data.prompt += "<|turn>model\n";
|
||||
data.generation_prompt = "<|turn>model\n";
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
|
|
@ -1115,13 +1209,25 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
|||
"<|turn>",
|
||||
};
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = string_ends_with(data.prompt, "<turn|>\n") ? "<|turn>model\n" : "";
|
||||
data.generation_prompt += "<|channel>thought\n" + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += "<channel|>" + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto start = p.rule("start", p.prefix(inputs.generation_prompt, "<|channel>"));
|
||||
auto start = p.rule("start", p.optional(p.literal("<|turn>model\n")));
|
||||
|
||||
if (extract_reasoning) {
|
||||
p.rule("thought", p.literal("<|channel>thought") + p.space() + p.reasoning(p.until("<channel|>")) + p.literal("<channel|>"));
|
||||
|
|
@ -1238,15 +1344,22 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
">>>all",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
data.generation_prompt = "<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\n" + msg.render_content();
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// Functionary v3.2 format:
|
||||
// - Normal content: >>>all\n{content}
|
||||
|
|
@ -1258,7 +1371,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
// When no tools, content goes until end
|
||||
auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
|
||||
auto content_until_end = p.literal("all\n") + p.content(p.rest());
|
||||
auto generation_prompt = p.literal(inputs.generation_prompt);
|
||||
auto generation_prompt = p.literal("<|start_header_id|>assistant<|end_header_id|>\n\n>>>");
|
||||
|
||||
// If no tools or tool_choice is NONE, just parse content
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
|
|
@ -1332,9 +1445,10 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_calls_section_end|>",
|
||||
|
|
@ -1357,10 +1471,22 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string GEN_PROMPT = "<|im_assistant|>assistant<|im_middle|>";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += THINK_END + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// Kimi K2 Thinking format:
|
||||
// - Reasoning: <think>{reasoning}</think>
|
||||
|
|
@ -1380,7 +1506,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning(
|
||||
p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) +
|
||||
p.optional(p.literal(THINK_END))) : p.eps();
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto generation_prompt = p.literal(GEN_PROMPT);
|
||||
|
||||
|
||||
// Content only parser (no tools)
|
||||
|
|
@ -1456,6 +1582,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
|
|
@ -1475,12 +1602,24 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
const std::string TOOL_CALL_END = "<|tool_call_end|>";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string GEN_PROMPT = "<|im_start|>assistant\n";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += THINK_END + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto generation_prompt = p.literal(GEN_PROMPT);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
|
|
@ -1535,6 +1674,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
|
|||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
|
|
@ -1550,12 +1690,24 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
|
|||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string GEN_PROMPT = "<|im_start|>assistant\n";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += THINK_END + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto generation_prompt = p.literal(GEN_PROMPT);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
|
|
@ -1606,6 +1758,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
|||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = false;
|
||||
data.preserved_tokens = {
|
||||
|
|
@ -1613,6 +1766,12 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
|||
"<|role_sep|>\n",
|
||||
};
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
data.generation_prompt = "assistant<|role_sep|>\n" + msg.render_content();
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
|
||||
|
|
@ -1648,7 +1807,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
|||
ret = p.content(p.rest());
|
||||
}
|
||||
|
||||
return p.literal(inputs.generation_prompt) + ret;
|
||||
return p.literal("assistant<|role_sep|>\n") + ret;
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1676,12 +1835,13 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha
|
|||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
data.preserved_tokens = {
|
||||
"|DSML|",
|
||||
"<think>",
|
||||
"</think>",
|
||||
|
|
@ -1701,9 +1861,21 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha
|
|||
const std::string INVOKE_END = "</" + DSML + "invoke>";
|
||||
const std::string PARAM_START = "<" + DSML + "parameter";
|
||||
const std::string PARAM_END = "</" + DSML + "parameter>";
|
||||
const std::string GEN_PROMPT = "<|Assistant|>";
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += THINK_END + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto generation_prompt = p.literal(GEN_PROMPT);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
|
|
@ -2130,21 +2302,6 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
|||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Computes the generation prompt by rendering the template with
// `add_generation_prompt` off and then on, and returning the part of the second
// render that follows the longest prefix shared by both renders.
static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) {
    // Toggle the flag on a private copy; `inputs` stays untouched.
    autoparser::generation_params params = inputs;

    params.add_generation_prompt = false;
    const std::string without_gen = common_chat_template_direct_apply_impl(tmpl, params);

    params.add_generation_prompt = true;
    const std::string with_gen = common_chat_template_direct_apply_impl(tmpl, params);

    const size_t limit = std::min(without_gen.size(), with_gen.size());
    size_t shared = 0;
    while (shared < limit && without_gen[shared] == with_gen[shared]) {
        ++shared;
    }

    return with_gen.substr(shared);
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::generation_params params;
|
||||
|
|
@ -2163,6 +2320,27 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
params.continue_final_message = inputs.continue_final_message;
|
||||
if (params.continue_final_message != COMMON_CHAT_CONTINUATION_NONE) {
|
||||
params.add_generation_prompt = false;
|
||||
|
||||
if (!inputs.messages.empty()) {
|
||||
// Render messages[:-1] and store continuation message separately
|
||||
params.continue_msg = inputs.messages.back();
|
||||
params.messages.erase(params.messages.size() - 1);
|
||||
}
|
||||
|
||||
if (params.continue_final_message == COMMON_CHAT_CONTINUATION_AUTO && !inputs.messages.empty()) {
|
||||
// Resolve based on message content
|
||||
params.continue_final_message = COMMON_CHAT_CONTINUATION_CONTENT;
|
||||
if (!params.continue_msg.reasoning_content.empty() &&
|
||||
params.continue_msg.content.empty() &&
|
||||
params.continue_msg.content_parts.empty()) {
|
||||
params.continue_final_message = COMMON_CHAT_CONTINUATION_REASONING;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
|
|
@ -2183,8 +2361,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params);
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
params.extra_context[el.first] = json::parse(el.second);
|
||||
|
|
@ -2214,17 +2390,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.generation_prompt = params.generation_prompt;
|
||||
auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) {
|
||||
return p.prefix(params.generation_prompt) << p.content(p.rest());
|
||||
auto parser = build_chat_peg_parser([&data](common_chat_peg_builder &p) {
|
||||
return p.literal(data.generation_prompt) << p.content(p.rest());
|
||||
});
|
||||
data.parser = parser.save();
|
||||
return data;
|
||||
}
|
||||
|
||||
if (auto result = common_chat_try_specialized_template(tmpl, src, params)) {
|
||||
result->generation_prompt = params.generation_prompt;
|
||||
return *result;
|
||||
}
|
||||
|
||||
|
|
@ -2238,7 +2413,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start);
|
||||
auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end);
|
||||
}
|
||||
auto_params.generation_prompt = params.generation_prompt;
|
||||
common_peg_arena arena;
|
||||
arena.load(auto_params.parser);
|
||||
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
|
||||
|
|
|
|||
|
|
@ -89,6 +89,8 @@ struct common_chat_msg {
|
|||
|
||||
nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
|
||||
|
||||
std::string render_content(const std::string & delimiter = "\n\n") const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() &&
|
||||
tool_name.empty() && tool_call_id.empty();
|
||||
|
|
@ -164,12 +166,22 @@ enum common_chat_format {
|
|||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
|
||||
// Selects how a final message is continued; the value travels in the
// `continue_final_message` field.
// NOTE(review): the per-value meanings below are read from the enumerator names - confirm against the code that consumes them.
enum common_chat_continuation {
    COMMON_CHAT_CONTINUATION_NONE      = 0, // no continuation
    COMMON_CHAT_CONTINUATION_AUTO      = 1, // presumably resolved to REASONING or CONTENT by the template code - confirm
    COMMON_CHAT_CONTINUATION_REASONING = 2, // continue the reasoning part of the message
    COMMON_CHAT_CONTINUATION_CONTENT   = 3, // continue the content part of the message
};
|
||||
|
||||
struct common_chat_templates_inputs {
|
||||
std::vector<common_chat_msg> messages;
|
||||
std::string grammar;
|
||||
std::string json_schema;
|
||||
bool add_generation_prompt = true;
|
||||
bool use_jinja = true;
|
||||
bool add_generation_prompt = true;
|
||||
common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE;
|
||||
bool use_jinja = true;
|
||||
// Parameters below only supported when use_jinja is true
|
||||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
|
|
@ -207,6 +219,7 @@ struct common_chat_parser_params {
|
|||
bool reasoning_in_content = false;
|
||||
std::string generation_prompt;
|
||||
bool parse_tool_calls = true;
|
||||
bool echo = false; // Include assistant prefilled msg in output
|
||||
bool debug = false; // Enable debug output for PEG parser
|
||||
common_peg_arena parser = {};
|
||||
common_chat_parser_params() = default;
|
||||
|
|
@ -267,6 +280,8 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::or
|
|||
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
|
||||
|
||||
common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value);
|
||||
|
||||
// DEPRECATED: only used in tests
|
||||
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
|
||||
|
|
@ -279,6 +294,10 @@ std::string common_chat_template_direct_apply(
|
|||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs);
|
||||
|
||||
std::string common_chat_template_generation_prompt(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs);
|
||||
|
||||
std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
|
|
|
|||
|
|
@ -379,7 +379,7 @@ void common_init() {
|
|||
llama_log_set(common_log_default_callback, NULL);
|
||||
}
|
||||
|
||||
void common_params_print_info(const common_params & params) {
|
||||
void common_params_print_info(const common_params & params, bool print_devices) {
|
||||
#ifdef NDEBUG
|
||||
const char * build_type = "";
|
||||
#else
|
||||
|
|
@ -388,12 +388,16 @@ void common_params_print_info(const common_params & params) {
|
|||
LOG_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
|
||||
|
||||
LOG_INF("log_info: verbosity = %d (adjust with the `-lv N` CLI arg)\n", common_log_get_verbosity_thold());
|
||||
LOG_INF("device_info:\n");
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
|
||||
// device enumeration creates a primary context on CUDA backends, skip it when the caller does not own any device
|
||||
if (print_devices) {
|
||||
LOG_INF("device_info:\n");
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -618,8 +618,6 @@ struct common_params {
|
|||
// UI configs
|
||||
#ifdef LLAMA_UI_DEFAULT_ENABLED
|
||||
bool ui = LLAMA_UI_DEFAULT_ENABLED != 0;
|
||||
#elif defined(LLAMA_WEBUI_DEFAULT_ENABLED)
|
||||
bool ui = LLAMA_WEBUI_DEFAULT_ENABLED != 0;
|
||||
#else
|
||||
bool ui = true; // default to enabled when not set
|
||||
#endif
|
||||
|
|
@ -709,7 +707,7 @@ struct common_params {
|
|||
// initializes the logging system and prints info about the build
|
||||
void common_init();
|
||||
|
||||
void common_params_print_info(const common_params & params);
|
||||
void common_params_print_info(const common_params & params, bool print_devices = true);
|
||||
std::string common_params_get_system_info(const common_params & params);
|
||||
|
||||
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
||||
|
|
|
|||
|
|
@ -471,7 +471,7 @@ void common_ngram_map_draft(common_ngram_map & map,
|
|||
sum_occur += curr_occur;
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
|
||||
LOG_DBG("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
|
||||
key_offset,
|
||||
max_occur, sum_occur, slot_max,
|
||||
curr_key.values[0].value_idx, curr_key.values[0].value_num,
|
||||
|
|
@ -482,7 +482,7 @@ void common_ngram_map_draft(common_ngram_map & map,
|
|||
// Print the tokens of the four values (if idx != 0), use LOG_DBG
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (curr_key.values[v].value_idx != 0) {
|
||||
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
|
||||
LOG_DBG("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -146,8 +146,11 @@ struct common_speculative_impl {
|
|||
|
||||
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
|
||||
|
||||
// true if this implementation requires the target context to extract embeddings
|
||||
// true if this implementation requires the target context to extract post-norm embeddings
|
||||
virtual bool need_embd() const = 0;
|
||||
|
||||
// true if this implementation requires the target context to extract pre-norm embeddings
|
||||
virtual bool need_embd_pre_norm() const { return false; }
|
||||
};
|
||||
|
||||
struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
|
|
@ -429,8 +432,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
|
||||
}
|
||||
|
||||
llama_set_embeddings_pre_norm(ctx_tgt, true);
|
||||
llama_set_embeddings_pre_norm(ctx_dft, true);
|
||||
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
|
||||
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
|
|
@ -691,6 +694,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
}
|
||||
|
||||
// This implementation always answers false to the need_embd() query.
bool need_embd() const override { return false; }
|
||||
|
||||
// This implementation always answers true to the need_embd_pre_norm() query.
bool need_embd_pre_norm() const override { return true; }
|
||||
};
|
||||
|
|
@ -1408,6 +1415,20 @@ bool common_speculative_need_embd(common_speculative * spec) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Reports whether at least one implementation held by `spec` answers true to
// need_embd_pre_norm(). A null `spec` yields false.
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
    bool needed = false;

    if (spec != nullptr) {
        for (auto & impl : spec->impls) {
            if (impl->need_embd_pre_norm()) {
                // one positive answer is enough, stop scanning
                needed = true;
                break;
            }
        }
    }

    return needed;
}
|
||||
|
||||
void common_speculative_draft(common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
|
|||
// process the batch and update the internal state of the speculative context
|
||||
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
|
||||
|
||||
// true if any implementation requires target embeddings to be extracted
|
||||
// true if any implementation requires target post-norm embeddings to be extracted
|
||||
bool common_speculative_need_embd(common_speculative * spec);
|
||||
|
||||
// true if any implementation requires target pre-norm embeddings to be extracted
|
||||
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
|
||||
|
||||
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
|
||||
void common_speculative_draft(common_speculative * spec);
|
||||
|
||||
|
|
|
|||
|
|
@ -600,6 +600,7 @@ class _Qwen35MtpMixin:
|
|||
if name.find("layers.") != -1:
|
||||
assert bid is not None
|
||||
name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
|
||||
bid = bid + n_layer
|
||||
else:
|
||||
remapper = {
|
||||
"mtp.fc": "model.layers.{bid}.eh_proj",
|
||||
|
|
|
|||
|
|
@ -140,11 +140,12 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa
|
|||
};
|
||||
|
||||
switch (nc) {
|
||||
case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
|
||||
case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
|
||||
case 5: launch_kernel(std::integral_constant<int, 5>{}); break;
|
||||
case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
|
||||
default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now.");
|
||||
case 3: launch_kernel(std::integral_constant<int, 3 >{}); break;
|
||||
case 4: launch_kernel(std::integral_constant<int, 4 >{}); break;
|
||||
case 5: launch_kernel(std::integral_constant<int, 5 >{}); break;
|
||||
case 9: launch_kernel(std::integral_constant<int, 9 >{}); break;
|
||||
case 15: launch_kernel(std::integral_constant<int, 15>{}); break;
|
||||
default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now.");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# include <cub/cub.cuh>
|
||||
// CUB_TOP_K_AVAILABLE is enabled for CCCL 3.2 and newer. Major and minor are compared as a
// pair so that a later major release with a small minor number (e.g. 4.0) still qualifies;
// testing `major >= 3 && minor >= 2` independently would reject it.
#    if (CCCL_MAJOR_VERSION > 3 || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 2))
#        define CUB_TOP_K_AVAILABLE
#        include <cuda/iterator>
using namespace cub;
#    endif  // CCCL version >= 3.2
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
|
|
|||
|
|
@ -50,7 +50,6 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
|||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <future>
|
||||
#include <thread>
|
||||
|
|
@ -765,8 +764,8 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_roll_f32;
|
||||
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
||||
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
|
||||
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
|
||||
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
|
||||
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
|
||||
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
|
||||
|
|
@ -860,6 +859,8 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_ssm_scan_f32_d128;
|
||||
vk_pipeline pipeline_ssm_scan_f32_d256;
|
||||
vk_pipeline pipeline_ssm_conv_f32;
|
||||
vk_pipeline pipeline_ssm_conv_silu_f32;
|
||||
vk_pipeline pipeline_ssm_conv_bias_silu_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
vk_pipeline pipeline_opt_step_sgd_f32;
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
|
||||
|
|
@ -1358,6 +1359,8 @@ struct vk_op_rope_push_constants {
|
|||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t a_offset;
|
||||
uint32_t d_offset;
|
||||
};
|
||||
static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
|
||||
|
||||
|
|
@ -4574,6 +4577,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
|
|
@ -4582,6 +4586,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
|
|
@ -4906,7 +4911,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
|
|
@ -7566,6 +7573,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_cpy_f32_bf16;
|
||||
}
|
||||
}
|
||||
if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) {
|
||||
if (contig) {
|
||||
return ctx->device->pipeline_contig_cpy_bf16_f32;
|
||||
} else {
|
||||
return ctx->device->pipeline_cpy_bf16_f32;
|
||||
}
|
||||
}
|
||||
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
|
||||
if (contig) {
|
||||
return ctx->device->pipeline_contig_cpy_f32_i32;
|
||||
|
|
@ -9964,7 +9978,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return nullptr;
|
||||
case GGML_OP_SSM_CONV:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_ssm_conv_f32;
|
||||
switch (ctx->num_additional_fused_ops) {
|
||||
case 0: return ctx->device->pipeline_ssm_conv_f32;
|
||||
case 1: return ctx->device->pipeline_ssm_conv_silu_f32;
|
||||
case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32;
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
|
|
@ -10145,6 +10164,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
|
|||
GGML_UNUSED(src3);
|
||||
}
|
||||
|
||||
// Rope push constants: a_offset and d_offset are set to the result of get_misalign_bytes()
// for src0 and dst respectively, divided by the size of the corresponding tensor type.
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
    // only src0 and dst contribute an offset here
    GGML_UNUSED(src1);
    GGML_UNUSED(src2);
    GGML_UNUSED(src3);

    const auto a_misalign = get_misalign_bytes(ctx, src0);
    const auto d_misalign = get_misalign_bytes(ctx, dst);

    p.a_offset = a_misalign / ggml_type_size(src0->type);
    p.d_offset = d_misalign / ggml_type_size(dst->type);
}
|
||||
|
||||
template<typename PC>
|
||||
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) {
|
||||
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
||||
|
|
@ -10905,11 +10933,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
pc, elements);
|
||||
}
|
||||
|
||||
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
ggml_tensor * conv = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * src0 = conv->src[0];
|
||||
const ggml_tensor * src1 = conv->src[1];
|
||||
|
||||
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
|
||||
// Pick the destination tensor (last node in the fused chain) and the optional bias.
|
||||
// Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu.
|
||||
ggml_tensor * dst = conv;
|
||||
const ggml_tensor * bias = nullptr;
|
||||
|
||||
if (ctx->num_additional_fused_ops == 1) {
|
||||
dst = cgraph->nodes[node_idx + 1]; // silu
|
||||
} else if (ctx->num_additional_fused_ops == 2) {
|
||||
ggml_tensor * add = cgraph->nodes[node_idx + 1];
|
||||
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
|
||||
dst = cgraph->nodes[node_idx + 2]; // silu
|
||||
}
|
||||
|
||||
// The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused.
|
||||
const ggml_tensor * src2 = bias ? bias : src0;
|
||||
|
||||
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, {
|
||||
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
|
||||
(uint32_t)src1->nb[1],
|
||||
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
|
||||
|
|
@ -11272,6 +11317,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
|
|||
(uint32_t)src0->ne[2],
|
||||
nb01, nb02, nb03,
|
||||
nb11, nb12, nb13,
|
||||
0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets
|
||||
};
|
||||
|
||||
return rope;
|
||||
|
|
@ -11367,6 +11413,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
GGML_ASSERT(buf[i] != nullptr);
|
||||
}
|
||||
|
||||
// a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned.
|
||||
// Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset.
|
||||
pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type);
|
||||
offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1);
|
||||
|
||||
std::array<uint32_t, 3> elements;
|
||||
elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
|
||||
|
||||
|
|
@ -13584,7 +13635,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
break;
|
||||
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_vk_ssm_conv(ctx, compute_ctx, node);
|
||||
ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx);
|
||||
|
||||
break;
|
||||
|
||||
|
|
@ -14481,6 +14532,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
|
|||
return true;
|
||||
}
|
||||
|
||||
// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2.
|
||||
static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
||||
int node_idx, int num_extra) {
|
||||
const ggml_tensor * conv = cgraph->nodes[node_idx];
|
||||
if (conv->op != GGML_OP_SSM_CONV) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_tensor * silu = nullptr;
|
||||
const ggml_tensor * bias = nullptr;
|
||||
|
||||
if (num_extra == 1) {
|
||||
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) {
|
||||
return false;
|
||||
}
|
||||
silu = cgraph->nodes[node_idx + 1];
|
||||
} else if (num_extra == 2) {
|
||||
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) {
|
||||
return false;
|
||||
}
|
||||
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
|
||||
silu = cgraph->nodes[node_idx + 2];
|
||||
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
|
||||
|
||||
if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
|
||||
return false;
|
||||
}
|
||||
// bias must be channel-wise (one element per channel of the conv output)
|
||||
if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) {
|
||||
return false;
|
||||
}
|
||||
if (add->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
// The shader doesn't apply per-tensor offsets, so reject misaligned bias.
|
||||
if (get_misalign_bytes(ctx, bias) != 0) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) {
|
||||
return false;
|
||||
}
|
||||
if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
// The shader writes to the fused dst using its own strides, but the push constants don't
|
||||
// carry a per-tensor offset, so the binding must be naturally aligned.
|
||||
if (get_misalign_bytes(ctx, silu) != 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
||||
int node_idx, topk_moe_mode mode) {
|
||||
|
||||
|
|
@ -14897,6 +15004,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
// they are overwritten, and one workgroup per row. So close enough.
|
||||
op_srcs_fused_elementwise[0] = true;
|
||||
op_srcs_fused_elementwise[1] = true;
|
||||
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) {
|
||||
ctx->num_additional_fused_ops = 2;
|
||||
fusion_string = "SSM_CONV_BIAS_SILU";
|
||||
// ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs.
|
||||
// The downstream add and silu are elementwise on the conv output.
|
||||
op_srcs_fused_elementwise[0] = false;
|
||||
op_srcs_fused_elementwise[1] = true;
|
||||
op_srcs_fused_elementwise[2] = true;
|
||||
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
fusion_string = "SSM_CONV_SILU";
|
||||
op_srcs_fused_elementwise[0] = false;
|
||||
op_srcs_fused_elementwise[1] = true;
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
|
||||
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
|
||||
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
|
||||
|
|
@ -15228,7 +15348,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
|
|
@ -15311,6 +15433,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
}
|
||||
}
|
||||
}
|
||||
// SSM_CONV + ADD + UNARY: pull the consuming UNARY forward
|
||||
if (j > 0 &&
|
||||
graph->nodes[j]->op == GGML_OP_ADD &&
|
||||
graph->nodes[j-1]->op == GGML_OP_SSM_CONV) {
|
||||
for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
|
||||
if (graph->nodes[k]->op == GGML_OP_UNARY &&
|
||||
graph->nodes[k]->src[0] == graph->nodes[j]) {
|
||||
current_set.push_back(k);
|
||||
used[k] = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Second pass grabs view nodes.
|
||||
|
|
@ -15875,6 +16010,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
if (src1_type == GGML_TYPE_F32) {
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@ void main() {
|
|||
if (idx + (num_iter-1)*num_threads < p.ne) {
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
|
||||
#if defined(DATA_D_BF16)
|
||||
#if defined(DATA_A_BF16)
|
||||
data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx])));
|
||||
#elif defined(DATA_D_BF16)
|
||||
float f = float(data_a[get_aoffset() + idx]);
|
||||
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
|
||||
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
|
||||
|
|
@ -35,7 +37,9 @@ void main() {
|
|||
continue;
|
||||
}
|
||||
|
||||
#if defined(DATA_D_BF16)
|
||||
#if defined(DATA_A_BF16)
|
||||
data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx])));
|
||||
#elif defined(DATA_D_BF16)
|
||||
float f = float(data_a[get_aoffset() + idx]);
|
||||
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
|
||||
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ void main() {
|
|||
return;
|
||||
}
|
||||
|
||||
#if defined(DATA_D_BF16)
|
||||
#if defined(DATA_A_BF16)
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + src0_idx(idx)])));
|
||||
#elif defined(DATA_D_BF16)
|
||||
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
|
||||
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03,
|
|||
// Per-row offset in shared memory
|
||||
const uint ix = i0;
|
||||
#else
|
||||
const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
|
||||
const uint ix = p.a_offset + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
|
||||
#endif
|
||||
return ix;
|
||||
}
|
||||
|
|
@ -48,6 +48,7 @@ void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_
|
|||
idst = i1*p.nb11 + i0;
|
||||
idst += rope_data_i[i2].x * p.set_rows_stride;
|
||||
}
|
||||
idst += p.d_offset;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]);
|
||||
|
|
@ -84,6 +85,7 @@ void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_
|
|||
idst = i1*p.nb11 + i0/2;
|
||||
idst += rope_data_i[i2].x * p.set_rows_stride;
|
||||
}
|
||||
idst += p.d_offset;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
|
||||
|
|
@ -121,6 +123,7 @@ void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope
|
|||
idst = i1*p.nb11 + i0/2;
|
||||
idst += rope_data_i[i2].x * p.set_rows_stride;
|
||||
}
|
||||
idst += p.d_offset;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
|
||||
|
|
@ -176,7 +179,7 @@ void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rop
|
|||
return;
|
||||
}
|
||||
|
||||
const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
|
||||
const uint idst = p.d_offset + i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
|
||||
const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1];
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ struct rope_params {
|
|||
uint nb11;
|
||||
uint nb12;
|
||||
uint nb13;
|
||||
|
||||
uint a_offset;
|
||||
uint d_offset;
|
||||
};
|
||||
|
||||
#endif // !defined(GGML_ROPE_PARAMS)
|
||||
|
|
|
|||
|
|
@ -6,12 +6,15 @@
|
|||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
|
||||
layout(constant_id = 2) const bool APPLY_BIAS = false;
|
||||
layout(constant_id = 3) const bool APPLY_SILU = false;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
|
||||
|
||||
layout(binding = 0) readonly buffer Src0 { float src0[]; };
|
||||
layout(binding = 1) readonly buffer Src1 { float src1[]; };
|
||||
layout(binding = 2) buffer Dst { float dst[]; };
|
||||
layout(binding = 2) readonly buffer Bias { float bias[]; };
|
||||
layout(binding = 3) buffer Dst { float dst[]; };
|
||||
|
||||
layout(push_constant) uniform PushConstants {
|
||||
uint nb01; uint nb02;
|
||||
|
|
@ -45,6 +48,13 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
if (APPLY_BIAS) {
|
||||
sum += bias[i1];
|
||||
}
|
||||
if (APPLY_SILU) {
|
||||
sum = sum / (1.0f + exp(-sum));
|
||||
}
|
||||
|
||||
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
|
||||
dst[dst_idx] = sum;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -748,6 +748,7 @@ void process_shaders() {
|
|||
string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
||||
string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
||||
string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
|
||||
string_to_spv("cpy_bf16_f32","copy.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}});
|
||||
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
|
||||
string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
|
||||
|
|
@ -755,6 +756,7 @@ void process_shaders() {
|
|||
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
||||
string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
||||
string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
|
||||
string_to_spv("contig_cpy_bf16_f32","contig_copy.comp",{{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}});
|
||||
string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
|
||||
string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
|
||||
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ dry_seq_break_max = 128
|
|||
extra_images_max = 4 # for kontext/qwen img
|
||||
|
||||
# global vars
|
||||
KcppVersion = "1.113.1"
|
||||
KcppVersion = "1.114"
|
||||
showdebug = True
|
||||
kcpp_instance = None #global running instance
|
||||
global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_base_config":"", "last_active_timestamp":datetime.now(), "triggered_sleeping":False, "current_model":"initial_model", "base_config":"", "swapReqType": None, "autoswapmode": False}
|
||||
|
|
|
|||
|
|
@ -67,8 +67,9 @@ llama_context::llama_context(
|
|||
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.embeddings_pre_norm = false;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.embeddings_pre_norm = false;
|
||||
cparams.embeddings_pre_norm_masked = false;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
|
|
@ -905,8 +906,17 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
|
|||
throw std::runtime_error("no pre-norm embeddings");
|
||||
}
|
||||
|
||||
const int64_t j = output_resolve_row(i);
|
||||
const uint32_t n_embd = model.hparams.n_embd;
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked) {
|
||||
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
|
||||
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
|
||||
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
|
||||
}
|
||||
return embd_pre_norm.data + (size_t) i * n_embd;
|
||||
}
|
||||
|
||||
const int64_t j = output_resolve_row(i);
|
||||
return embd_pre_norm.data + j*n_embd;
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
|
||||
|
|
@ -1098,10 +1108,11 @@ void llama_context::set_embeddings(bool value) {
|
|||
//sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_embeddings_pre_norm(bool value) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||
void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);
|
||||
|
||||
cparams.embeddings_pre_norm = value;
|
||||
cparams.embeddings_pre_norm = value;
|
||||
cparams.embeddings_pre_norm_masked = masked;
|
||||
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
|
|
@ -1747,6 +1758,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
};
|
||||
|
||||
int64_t n_outputs_prev = 0;
|
||||
int64_t n_tokens_prev = 0;
|
||||
|
||||
do {
|
||||
const auto & ubatch = mctx->get_ubatch();
|
||||
|
|
@ -1892,16 +1904,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
|
||||
// extract pre-norm embeddings (hidden state before the final output norm)
|
||||
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
|
||||
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
{
|
||||
const bool masked = cparams.embeddings_pre_norm_masked;
|
||||
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
|
||||
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
|
||||
if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;
|
||||
|
||||
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// Copy backend sampling output if this ubatch produced any sampling tensors.
|
||||
|
|
@ -1918,6 +1935,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
}
|
||||
|
||||
n_outputs_prev += n_outputs;
|
||||
n_tokens_prev += ubatch.n_tokens;
|
||||
} while (mctx->next());
|
||||
|
||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||
|
|
@ -2009,6 +2027,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
|
||||
|
||||
if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
|
||||
// unmasked: pre-norm row exists for every token in the batch, not just
|
||||
// those flagged via batch.logits[i] -> size by token count instead.
|
||||
embd_pre_norm.size = (size_t) n_embd * n_batch;
|
||||
}
|
||||
|
||||
// Allocate backend sampling output buffers if there are backend samplers configured.
|
||||
const bool has_sampling = !sampling.samplers.empty();
|
||||
if (has_sampling) {
|
||||
|
|
@ -3557,8 +3581,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
|||
return ctx->get_embeddings_seq(seq_id);
|
||||
}
|
||||
|
||||
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) {
|
||||
ctx->set_embeddings_pre_norm(value);
|
||||
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
|
||||
ctx->set_embeddings_pre_norm(value, masked);
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_pre_norm(llama_context * ctx) {
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ struct llama_context {
|
|||
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
|
||||
|
||||
void set_embeddings (bool value);
|
||||
void set_embeddings_pre_norm(bool value);
|
||||
void set_embeddings_pre_norm(bool value, bool masked);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ struct llama_cparams {
|
|||
float yarn_beta_slow;
|
||||
|
||||
bool embeddings;
|
||||
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
|
||||
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
|
||||
bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0
|
||||
bool causal_attn;
|
||||
bool offload_kqv;
|
||||
bool flash_attn;
|
||||
|
|
|
|||
|
|
@ -93,14 +93,14 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
|
|||
// pre-norm embeddings (hidden state before the final output norm)
|
||||
//
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
|
||||
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value);
|
||||
// Set whether the context outputs pre-norm embeddings or not
|
||||
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
|
||||
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
|
||||
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);
|
||||
|
|
|
|||
|
|
@ -848,6 +848,9 @@ void llm_graph_result::set_outputs() {
|
|||
if (t_embd_pooled != nullptr) {
|
||||
ggml_set_output(t_embd_pooled);
|
||||
}
|
||||
if (t_h_pre_norm != nullptr) {
|
||||
ggml_set_output(t_h_pre_norm);
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
|||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
|
||||
}
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
|
@ -211,6 +211,10 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
|||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
// Final norm
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
|||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
|
||||
}
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
|
@ -234,6 +234,10 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
|||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
// Final norm
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
|
|
|
|||
|
|
@ -1032,23 +1032,33 @@ json oaicompat_chat_params_parse(
|
|||
auto caps = common_chat_templates_get_caps(opt.tmpls.get());
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = common_chat_tools_parse_oaicompat(tools);
|
||||
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
|
||||
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
|
||||
inputs.grammar = grammar;
|
||||
inputs.use_jinja = opt.use_jinja;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]);
|
||||
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
|
||||
const bool continue_final_message = json_value(body, "continue_final_message", false);
|
||||
if (continue_final_message && inputs.add_generation_prompt) {
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = common_chat_tools_parse_oaicompat(tools);
|
||||
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
|
||||
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
|
||||
inputs.grammar = grammar;
|
||||
inputs.use_jinja = opt.use_jinja;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]);
|
||||
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
|
||||
inputs.continue_final_message = body.contains("continue_final_message") ?
|
||||
common_chat_continuation_parse(body.at("continue_final_message")) :
|
||||
COMMON_CHAT_CONTINUATION_NONE;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_NONE && opt.prefill_assistant
|
||||
&& !inputs.messages.empty() && inputs.messages.back().role == "assistant") {
|
||||
if (inputs.messages.size() >= 2 && inputs.messages[inputs.messages.size() - 2].role == "assistant") {
|
||||
throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list.");
|
||||
}
|
||||
inputs.continue_final_message = COMMON_CHAT_CONTINUATION_AUTO;
|
||||
inputs.add_generation_prompt = false;
|
||||
}
|
||||
if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && inputs.add_generation_prompt) {
|
||||
throw std::invalid_argument("Cannot set both add_generation_prompt and continue_final_message to true.");
|
||||
}
|
||||
inputs.reasoning_format = opt.reasoning_format;
|
||||
inputs.reasoning_format = opt.reasoning_format;
|
||||
if (body.contains("reasoning_format")) {
|
||||
inputs.reasoning_format = common_reasoning_format_from_name(body.at("reasoning_format").get<std::string>());
|
||||
}
|
||||
inputs.enable_thinking = opt.enable_thinking;
|
||||
inputs.enable_thinking = opt.enable_thinking;
|
||||
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
if (body.contains("grammar")) {
|
||||
throw std::invalid_argument("Cannot use custom grammar constraints with tools.");
|
||||
|
|
@ -1073,84 +1083,11 @@ json oaicompat_chat_params_parse(
|
|||
throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)");
|
||||
}
|
||||
|
||||
// if the assistant message appears at the end of list, we do not add end-of-turn token
|
||||
// for ex. this can be useful to modify the reasoning process in reasoning models
|
||||
// continue_final_message is the explicit opt in alias from the vLLM/transformers API,
|
||||
// equivalent to the prefill_assistant heuristic
|
||||
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant"
|
||||
&& (continue_final_message || opt.prefill_assistant);
|
||||
common_chat_msg last_message;
|
||||
if (prefill_assistant_message) {
|
||||
last_message = inputs.messages.back();
|
||||
inputs.messages.pop_back();
|
||||
|
||||
/* sanity check, max one assistant message at the end of the list */
|
||||
if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
|
||||
throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list.");
|
||||
}
|
||||
|
||||
// reject reasoning prefill on channel based templates that do not expose explicit thinking tags
|
||||
if (!last_message.reasoning_content.empty() && inputs.enable_thinking) {
|
||||
auto probe_params = common_chat_templates_apply(opt.tmpls.get(), inputs);
|
||||
if (probe_params.supports_thinking && probe_params.thinking_end_tag.empty()) {
|
||||
throw std::invalid_argument("Assistant prefill with reasoning_content is not supported yet for this template.");
|
||||
}
|
||||
}
|
||||
|
||||
inputs.add_generation_prompt = true;
|
||||
}
|
||||
inputs.force_pure_content = opt.force_pure_content;
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
auto chat_params = common_chat_templates_apply(opt.tmpls.get(), inputs);
|
||||
|
||||
/* Append assistant prefilled message */
|
||||
if (prefill_assistant_message) {
|
||||
const bool thinking_active = chat_params.supports_thinking && !chat_params.thinking_end_tag.empty();
|
||||
const bool has_reasoning = !last_message.reasoning_content.empty();
|
||||
const bool has_content = !last_message.content.empty() || !last_message.content_parts.empty();
|
||||
const bool mid_reasoning = has_reasoning && !has_content;
|
||||
|
||||
// some templates inject thinking_start in generation_prompt, others let the model emit it
|
||||
const bool gp_has_think = thinking_active
|
||||
&& chat_params.generation_prompt.find(chat_params.thinking_start_tag) != std::string::npos;
|
||||
|
||||
// open the thinking block when reasoning is present and the template did not inject it
|
||||
if (has_reasoning) {
|
||||
if (thinking_active && !gp_has_think) {
|
||||
chat_params.prompt += chat_params.thinking_start_tag;
|
||||
}
|
||||
chat_params.prompt += last_message.reasoning_content;
|
||||
}
|
||||
|
||||
if (thinking_active) {
|
||||
if (mid_reasoning) {
|
||||
// model continues inside the thinking block, keep generation_prompt open on think
|
||||
if (!gp_has_think) {
|
||||
chat_params.generation_prompt += chat_params.thinking_start_tag;
|
||||
}
|
||||
} else {
|
||||
// close thinking block when reasoning is followed by content, or when the template forced it open
|
||||
if (has_reasoning || gp_has_think) {
|
||||
chat_params.prompt += chat_params.thinking_end_tag;
|
||||
}
|
||||
// strip thinking_start from generation_prompt so the parser routes model output as content
|
||||
auto pos = chat_params.generation_prompt.rfind(chat_params.thinking_start_tag);
|
||||
if (pos != std::string::npos) {
|
||||
chat_params.generation_prompt = chat_params.generation_prompt.substr(0, pos);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!last_message.content_parts.empty()) {
|
||||
for (auto & p : last_message.content_parts) {
|
||||
chat_params.prompt += p.text;
|
||||
}
|
||||
} else {
|
||||
chat_params.prompt += last_message.content;
|
||||
}
|
||||
}
|
||||
|
||||
llama_params["chat_format"] = static_cast<int>(chat_params.format);
|
||||
llama_params["prompt"] = chat_params.prompt;
|
||||
if (!chat_params.grammar.empty()) {
|
||||
|
|
|
|||
|
|
@ -243,6 +243,11 @@ struct server_slot {
|
|||
return task->need_embd() || (spec && common_speculative_need_embd(spec));
|
||||
}
|
||||
|
||||
bool need_embd_pre_norm() const {
|
||||
GGML_ASSERT(task);
|
||||
return spec && common_speculative_need_embd_pre_norm(spec);
|
||||
}
|
||||
|
||||
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
|
||||
// also we cannot split if the pooling would require any past tokens
|
||||
// (MTP supports splitting — uses task->need_embd() not need_embd())
|
||||
|
|
@ -4527,7 +4532,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
|
|||
}
|
||||
}
|
||||
|
||||
int embd_normalize = 2; // default to Euclidean/L2 norm
|
||||
int embd_normalize = params.embd_normalize;
|
||||
if (body.count("embd_normalize") != 0) {
|
||||
embd_normalize = body.at("embd_normalize");
|
||||
if (meta->pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
|
|
|
|||
|
|
@ -231,11 +231,10 @@ bool server_http_context::init(const common_params & params) {
|
|||
};
|
||||
|
||||
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
|
||||
(void)req; // suppress unused parameter warning when LLAMA_BUILD_UI / LLAMA_BUILD_WEBUI is not defined
|
||||
(void)req; // suppress unused parameter warning when LLAMA_BUILD_UI is not defined
|
||||
bool ready = is_ready.load();
|
||||
if (!ready) {
|
||||
// Support both old and new preprocessor defines
|
||||
#if defined(LLAMA_BUILD_UI) || defined(LLAMA_BUILD_WEBUI)
|
||||
#if defined(LLAMA_BUILD_UI)
|
||||
auto tmp = string_split<std::string>(req.path, '.');
|
||||
if (req.path == "/" || (tmp.size() > 0 && tmp.back() == "html")) {
|
||||
res.status = 503;
|
||||
|
|
@ -313,8 +312,7 @@ bool server_http_context::init(const common_params & params) {
|
|||
return 1;
|
||||
}
|
||||
} else {
|
||||
// Support both old and new preprocessor defines
|
||||
#if defined(LLAMA_BUILD_UI) || defined(LLAMA_BUILD_WEBUI)
|
||||
#if defined(LLAMA_BUILD_UI)
|
||||
// using embedded static index.html
|
||||
srv->Get(params.api_prefix + "/", [](const httplib::Request & /*req*/, httplib::Response & res) {
|
||||
// COEP and COOP headers, required by pyodide (python interpreter)
|
||||
|
|
|
|||
|
|
@ -798,9 +798,10 @@ void server_models::load(const std::string & name) {
|
|||
std::thread log_thread([&]() {
|
||||
// read stdout/stderr and forward to main server log
|
||||
// also handle status report from child process
|
||||
std::vector<char> vec_buf(128 * 1024); // large buffer for storing info
|
||||
char * buffer = vec_buf.data();
|
||||
if (stdout_file) {
|
||||
char buffer[128 * 1024]; // large buffer for storing info
|
||||
while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) {
|
||||
while (fgets(buffer, vec_buf.size(), stdout_file) != nullptr) {
|
||||
LOG("[%5d] %s", port, buffer);
|
||||
std::string str(buffer);
|
||||
if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) {
|
||||
|
|
|
|||
|
|
@ -144,6 +144,17 @@ json task_params::to_json(bool only_metrics) const {
|
|||
//
|
||||
// task_result_state
|
||||
//
|
||||
task_result_state::task_result_state(const common_chat_parser_params & chat_parser_params)
|
||||
: chat_parser_params(chat_parser_params)
|
||||
, oai_resp_id("resp_" + random_string())
|
||||
, oai_resp_reasoning_id("rs_" + random_string())
|
||||
, oai_resp_message_id("msg_" + random_string()) {
|
||||
if (!chat_parser_params.echo) {
|
||||
// initialize chat_msg to avoid emitting a delta containing the assistant prefill
|
||||
chat_msg = common_chat_parse("", true, chat_parser_params);
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg task_result_state::update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
|
|
@ -421,6 +432,7 @@ task_params server_task::params_from_json_cmpl(
|
|||
if (data.contains("chat_parser")) {
|
||||
params.chat_parser_params.parser.load(data.at("chat_parser").get<std::string>());
|
||||
}
|
||||
params.chat_parser_params.echo = json_value(data, "echo", false);
|
||||
}
|
||||
|
||||
{
|
||||
|
|
|
|||
|
|
@ -112,11 +112,7 @@ struct task_result_state {
|
|||
const std::string oai_resp_message_id;
|
||||
std::string oai_resp_fc_id; // function call ID for current args delta
|
||||
|
||||
task_result_state(const common_chat_parser_params & chat_parser_params)
|
||||
: chat_parser_params(chat_parser_params)
|
||||
, oai_resp_id("resp_" + random_string())
|
||||
, oai_resp_reasoning_id("rs_" + random_string())
|
||||
, oai_resp_message_id("msg_" + random_string()) {}
|
||||
task_result_state(const common_chat_parser_params & chat_parser_params);
|
||||
|
||||
// parse partial tool calls and update the internal state
|
||||
common_chat_msg update_chat_msg(
|
||||
|
|
|
|||
|
|
@ -86,7 +86,10 @@ int main(int argc, char ** argv) {
|
|||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
common_params_print_info(params);
|
||||
// router server never loads a model and must not touch the GPU
|
||||
// skip device enumeration so the CUDA primary context stays uncreated
|
||||
const bool is_router_server = params.model.path.empty();
|
||||
common_params_print_info(params, !is_router_server);
|
||||
|
||||
// validate batch size for embeddings
|
||||
// embeddings require all tokens to be processed in a single ubatch
|
||||
|
|
@ -126,7 +129,6 @@ int main(int argc, char ** argv) {
|
|||
server_routes routes(params, ctx_server);
|
||||
server_tools tools;
|
||||
|
||||
bool is_router_server = params.model.path.empty();
|
||||
std::optional<server_models_routes> models_routes{};
|
||||
if (is_router_server) {
|
||||
// setup server instances manager
|
||||
|
|
|
|||
|
|
@ -158,11 +158,12 @@ def test_chat_template():
|
|||
|
||||
@pytest.mark.parametrize("prefill,re_prefill", [
|
||||
("Whill", "Whill"),
|
||||
([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
|
||||
([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Wh\n\nill"),
|
||||
])
|
||||
def test_chat_template_assistant_prefill(prefill, re_prefill):
|
||||
global server
|
||||
server.chat_template = "llama3"
|
||||
server.jinja = True
|
||||
server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"
|
||||
server.debug = True # to get the "__verbose" object in the response
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
|
|
@ -175,14 +176,15 @@ def test_chat_template_assistant_prefill(prefill, re_prefill):
|
|||
})
|
||||
assert res.status_code == 200
|
||||
assert "__verbose" in res.body
|
||||
assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
|
||||
assert res.body["__verbose"]["prompt"].endswith(f"<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}")
|
||||
|
||||
|
||||
def test_chat_template_continue_final_message_vllm_compat():
|
||||
"""continue_final_message is the vLLM/transformers explicit alias for the prefill_assistant heuristic.
|
||||
Both must produce the same prompt."""
|
||||
global server
|
||||
server.chat_template = "llama3"
|
||||
server.jinja = True
|
||||
server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"
|
||||
server.debug = True
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
|
|
@ -197,7 +199,7 @@ def test_chat_template_continue_final_message_vllm_compat():
|
|||
})
|
||||
assert res.status_code == 200
|
||||
assert "__verbose" in res.body
|
||||
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill"
|
||||
assert res.body["__verbose"]["prompt"].endswith("<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill")
|
||||
|
||||
|
||||
def test_chat_template_continue_final_message_mutual_exclusion():
|
||||
|
|
|
|||
|
|
@ -14,12 +14,7 @@ endif()
|
|||
set(TARGET_SRCS "")
|
||||
set(UI_COMPILE_DEFS "")
|
||||
|
||||
# Support both old (LLAMA_BUILD_WEBUI) and new (LLAMA_BUILD_UI) option names
|
||||
if(LLAMA_BUILD_WEBUI OR LLAMA_BUILD_UI)
|
||||
if(LLAMA_BUILD_WEBUI AND NOT LLAMA_BUILD_UI)
|
||||
message(DEPRECATION "LLAMA_BUILD_WEBUI is deprecated, use LLAMA_BUILD_UI instead")
|
||||
endif()
|
||||
|
||||
if(LLAMA_BUILD_UI)
|
||||
set(PUBLIC_ASSETS
|
||||
index.html
|
||||
bundle.js
|
||||
|
|
@ -125,19 +120,17 @@ if(LLAMA_BUILD_WEBUI OR LLAMA_BUILD_UI)
|
|||
endforeach()
|
||||
|
||||
list(APPEND UI_COMPILE_DEFS
|
||||
LLAMA_BUILD_WEBUI # Deprecated: use LLAMA_BUILD_UI
|
||||
LLAMA_BUILD_UI
|
||||
LLAMA_WEBUI_DEFAULT_ENABLED=1 # Deprecated: use LLAMA_UI_DEFAULT_ENABLED
|
||||
LLAMA_UI_DEFAULT_ENABLED=1
|
||||
)
|
||||
message(STATUS "UI: embedded with source: ${UI_SOURCE}")
|
||||
else()
|
||||
message(WARNING "UI: no source available. Neither local build (build/tools/ui/dist/) nor HF Bucket download succeeded.")
|
||||
message(WARNING "UI: building server without embedded UI. Set LLAMA_BUILD_UI=OFF to suppress this warning.")
|
||||
list(APPEND UI_COMPILE_DEFS LLAMA_WEBUI_DEFAULT_ENABLED=0 LLAMA_UI_DEFAULT_ENABLED=0)
|
||||
list(APPEND UI_COMPILE_DEFS LLAMA_UI_DEFAULT_ENABLED=0)
|
||||
endif()
|
||||
else()
|
||||
list(APPEND UI_COMPILE_DEFS LLAMA_WEBUI_DEFAULT_ENABLED=0 LLAMA_UI_DEFAULT_ENABLED=0)
|
||||
list(APPEND UI_COMPILE_DEFS LLAMA_UI_DEFAULT_ENABLED=0)
|
||||
endif()
|
||||
|
||||
# Build the static library
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
@import 'tailwindcss';
|
||||
@source ".";
|
||||
|
||||
@import 'tw-animate-css';
|
||||
|
||||
|
|
@ -39,6 +40,9 @@
|
|||
--sidebar-ring: oklch(0.708 0 0);
|
||||
--code-background: oklch(0.985 0 0);
|
||||
--code-foreground: oklch(0.145 0 0);
|
||||
--font-mono:
|
||||
ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
|
||||
'Liberation Mono', Menlo, monospace;
|
||||
--layer-popover: 1000000;
|
||||
|
||||
--chat-form-area-height: 8rem;
|
||||
|
|
@ -171,6 +175,10 @@
|
|||
*::-webkit-scrollbar-thumb:hover {
|
||||
background: hsl(var(--muted-foreground) / 0.5);
|
||||
}
|
||||
|
||||
:where(code, pre, kbd, samp) {
|
||||
font-family: var(--font-mono);
|
||||
}
|
||||
}
|
||||
|
||||
@layer utilities {
|
||||
|
|
|
|||
2
tools/ui/src/app.d.ts
vendored
2
tools/ui/src/app.d.ts
vendored
|
|
@ -39,6 +39,7 @@ import type {
|
|||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraVideoFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraTextFile,
|
||||
DatabaseMessageExtraPdfFile,
|
||||
|
|
@ -102,6 +103,7 @@ declare global {
|
|||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraVideoFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraTextFile,
|
||||
DatabaseMessageExtraPdfFile,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { Eye, Mic } from '@lucide/svelte';
|
||||
import { Eye, Mic, Video } from '@lucide/svelte';
|
||||
import { ModelModality } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
|
|
@ -11,7 +11,7 @@
|
|||
</script>
|
||||
|
||||
{#each modalities as modality (modality)}
|
||||
{#if modality === ModelModality.VISION || modality === ModelModality.AUDIO}
|
||||
{#if modality === ModelModality.VISION || modality === ModelModality.AUDIO || modality === ModelModality.VIDEO}
|
||||
<span
|
||||
class={[
|
||||
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
|
||||
|
|
@ -21,7 +21,11 @@
|
|||
{#if modality === ModelModality.VISION}
|
||||
<Eye class="h-3 w-3" />
|
||||
|
||||
Vision
|
||||
Vision (Image)
|
||||
{:else if modality === ModelModality.VIDEO}
|
||||
<Video class="h-3 w-3" />
|
||||
|
||||
Vision (Video)
|
||||
{:else}
|
||||
<Mic class="h-3 w-3" />
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { X, Music, Video } from '@lucide/svelte';
|
||||
import {
|
||||
formatFileSize,
|
||||
getFileTypeLabel,
|
||||
getPreviewText,
|
||||
isPdfFile,
|
||||
isAudioFile,
|
||||
isVideoFile,
|
||||
isTextFile
|
||||
} from '$lib/utils';
|
||||
import { ActionIcon } from '$lib/components/app';
|
||||
|
|
@ -38,6 +40,8 @@
|
|||
}: Props = $props();
|
||||
|
||||
let isPdf = $derived(isPdfFile(attachment, uploadedFile));
|
||||
let isAudio = $derived(isAudioFile(attachment, uploadedFile));
|
||||
let isVideo = $derived(isVideoFile(attachment, uploadedFile));
|
||||
let isPdfWithContent = $derived(isPdf && !!textContent);
|
||||
|
||||
let isText = $derived(isTextFile(attachment, uploadedFile));
|
||||
|
|
@ -102,7 +106,13 @@
|
|||
<div
|
||||
class="flex h-8 w-8 items-center justify-center rounded bg-primary/10 text-xs font-medium text-primary"
|
||||
>
|
||||
{fileTypeLabel}
|
||||
{#if isAudio}
|
||||
<Music class="h-4 w-4 text-white/70" />
|
||||
{:else if isVideo}
|
||||
<Video class="h-4 w-4 text-white/70" />
|
||||
{:else}
|
||||
{fileTypeLabel}
|
||||
{/if}
|
||||
</div>
|
||||
{/snippet}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
getAttachmentDisplayItems,
|
||||
getLanguageFromFilename,
|
||||
isAudioFile,
|
||||
isVideoFile,
|
||||
isImageFile,
|
||||
isMcpPrompt,
|
||||
isMcpResource,
|
||||
|
|
@ -29,6 +30,7 @@
|
|||
textContent?: string;
|
||||
isImage: boolean;
|
||||
isAudio: boolean;
|
||||
isVideo: boolean;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
|
|
@ -54,7 +56,8 @@
|
|||
(item): PreviewItem => ({
|
||||
...item,
|
||||
isImage: isImageFile(item.attachment, item.uploadedFile),
|
||||
isAudio: isAudioFile(item.attachment, item.uploadedFile)
|
||||
isAudio: isAudioFile(item.attachment, item.uploadedFile),
|
||||
isVideo: isVideoFile(item.attachment, item.uploadedFile)
|
||||
})
|
||||
)
|
||||
);
|
||||
|
|
@ -102,6 +105,9 @@
|
|||
let isAudio = $derived(
|
||||
currentItem ? isAudioFile(currentItem.attachment, currentItem.uploadedFile) : false
|
||||
);
|
||||
let isVideo = $derived(
|
||||
currentItem ? isVideoFile(currentItem.attachment, currentItem.uploadedFile) : false
|
||||
);
|
||||
let isImage = $derived(
|
||||
currentItem ? isImageFile(currentItem.attachment, currentItem.uploadedFile) : false
|
||||
);
|
||||
|
|
@ -148,6 +154,20 @@
|
|||
: null
|
||||
);
|
||||
|
||||
let videoSrc = $derived(
|
||||
isVideo && currentItem
|
||||
? (currentItem.uploadedFile?.preview ??
|
||||
(currentItem.attachment &&
|
||||
'mimeType' in currentItem.attachment &&
|
||||
'base64Data' in currentItem.attachment
|
||||
? createBase64DataUrl(
|
||||
currentItem.attachment.mimeType,
|
||||
currentItem.attachment.base64Data
|
||||
)
|
||||
: null))
|
||||
: null
|
||||
);
|
||||
|
||||
export function prev() {
|
||||
currentIndex = currentIndex > 0 ? currentIndex - 1 : allItems.length - 1;
|
||||
}
|
||||
|
|
@ -173,11 +193,13 @@
|
|||
{currentItem}
|
||||
{isImage}
|
||||
{isAudio}
|
||||
{isVideo}
|
||||
{isPdf}
|
||||
{isText}
|
||||
{displayPreview}
|
||||
{displayTextContent}
|
||||
{audioSrc}
|
||||
{videoSrc}
|
||||
{language}
|
||||
{hasVisionModality}
|
||||
{activeModelId}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
<script lang="ts">
|
||||
import type { ChatAttachmentDisplayItem } from '$lib/types';
|
||||
import { Image, Music, FileText, FileIcon } from '@lucide/svelte';
|
||||
import { Image, Music, Video, FileText, FileIcon } from '@lucide/svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemPdf from './ChatAttachmentsPreviewCurrentItemPdf.svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemImage from './ChatAttachmentsPreviewCurrentItemImage.svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemAudio from './ChatAttachmentsPreviewCurrentItemAudio.svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemVideo from './ChatAttachmentsPreviewCurrentItemVideo.svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemText from './ChatAttachmentsPreviewCurrentItemText.svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemUnavailable from './ChatAttachmentsPreviewCurrentItemUnavailable.svelte';
|
||||
|
||||
|
|
@ -11,11 +12,13 @@
|
|||
currentItem: ChatAttachmentDisplayItem | null;
|
||||
isImage: boolean;
|
||||
isAudio: boolean;
|
||||
isVideo: boolean;
|
||||
isPdf: boolean;
|
||||
isText: boolean;
|
||||
displayPreview: string | undefined;
|
||||
displayTextContent: string | undefined;
|
||||
audioSrc: string | null;
|
||||
videoSrc: string | null;
|
||||
language: string;
|
||||
hasVisionModality: boolean;
|
||||
activeModelId?: string;
|
||||
|
|
@ -25,21 +28,25 @@
|
|||
currentItem,
|
||||
isImage,
|
||||
isAudio,
|
||||
isVideo,
|
||||
isPdf,
|
||||
isText,
|
||||
displayPreview,
|
||||
displayTextContent,
|
||||
audioSrc,
|
||||
videoSrc,
|
||||
language,
|
||||
hasVisionModality,
|
||||
activeModelId
|
||||
}: Props = $props();
|
||||
|
||||
let IconComponent = $derived(
|
||||
isImage ? Image : isText || isPdf ? FileText : isAudio ? Music : FileIcon
|
||||
isImage ? Image : isText || isPdf ? FileText : isAudio ? Music : isVideo ? Video : FileIcon
|
||||
);
|
||||
|
||||
let isUnavailable = $derived(!isPdf && !isImage && !(isText && displayTextContent) && !isAudio);
|
||||
let isUnavailable = $derived(
|
||||
!isPdf && !isImage && !(isText && displayTextContent) && !isAudio && !isVideo
|
||||
);
|
||||
</script>
|
||||
|
||||
{#if currentItem}
|
||||
|
|
@ -58,6 +65,8 @@
|
|||
<ChatAttachmentsPreviewCurrentItemText {displayTextContent} {language} />
|
||||
{:else if isAudio}
|
||||
<ChatAttachmentsPreviewCurrentItemAudio {currentItem} {audioSrc} />
|
||||
{:else if isVideo}
|
||||
<ChatAttachmentsPreviewCurrentItemVideo {currentItem} {videoSrc} />
|
||||
{:else if isUnavailable}
|
||||
<ChatAttachmentsPreviewCurrentItemUnavailable {IconComponent} />
|
||||
{/if}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
<script lang="ts">
|
||||
import { Video } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
currentItem: { name?: string } | null;
|
||||
videoSrc: string | null;
|
||||
}
|
||||
|
||||
let { currentItem, videoSrc }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="flex flex-1 items-center justify-center p-8">
|
||||
<div class="w-full max-w-md text-center">
|
||||
<Video class="mx-auto mb-4 h-16 w-16 text-white/50" />
|
||||
|
||||
{#if videoSrc}
|
||||
<video controls class="mb-4 w-full" src={videoSrc}>
|
||||
Your browser does not support the video element.
|
||||
</video>
|
||||
{:else}
|
||||
<p class="mb-4 text-white/70">Video preview not available</p>
|
||||
{/if}
|
||||
|
||||
<p class="text-sm text-white/50">{currentItem?.name || 'Video'}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { Music, FileText } from '@lucide/svelte';
|
||||
import { Music, Video, FileText } from '@lucide/svelte';
|
||||
import { HorizontalScrollCarousel } from '$lib/components/app/misc';
|
||||
|
||||
interface PreviewItem {
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
name: string;
|
||||
isImage: boolean;
|
||||
isAudio: boolean;
|
||||
isVideo: boolean;
|
||||
preview?: string;
|
||||
}
|
||||
|
||||
|
|
@ -49,6 +50,8 @@
|
|||
>
|
||||
{#if item.isAudio}
|
||||
<Music class="h-4 w-4 text-white/70" />
|
||||
{:else if item.isVideo}
|
||||
<Video class="h-4 w-4 text-white/70" />
|
||||
{:else}
|
||||
<FileText class="h-4 w-4 text-white/70" />
|
||||
{/if}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
class?: string;
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVideoModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
hasMcpPromptsSupport?: boolean;
|
||||
hasMcpResourcesSupport?: boolean;
|
||||
|
|
@ -37,6 +38,7 @@
|
|||
class: className = '',
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVideoModality = false,
|
||||
hasVisionModality = false,
|
||||
hasMcpPromptsSupport = false,
|
||||
hasMcpResourcesSupport = false,
|
||||
|
|
@ -58,6 +60,7 @@
|
|||
() => ({
|
||||
hasVisionModality,
|
||||
hasAudioModality,
|
||||
hasVideoModality,
|
||||
hasMcpPromptsSupport,
|
||||
hasMcpResourcesSupport
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
class?: string;
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVideoModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
hasMcpPromptsSupport?: boolean;
|
||||
hasMcpResourcesSupport?: boolean;
|
||||
|
|
@ -34,6 +35,7 @@
|
|||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVisionModality = false,
|
||||
hasVideoModality = false,
|
||||
hasMcpPromptsSupport = false,
|
||||
hasMcpResourcesSupport = false,
|
||||
onFileUpload,
|
||||
|
|
@ -49,6 +51,7 @@
|
|||
() => ({
|
||||
hasVisionModality,
|
||||
hasAudioModality,
|
||||
hasVideoModality,
|
||||
hasMcpPromptsSupport,
|
||||
hasMcpResourcesSupport
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
interface Props {
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVideoModality?: boolean;
|
||||
hasMcpPromptsSupport?: boolean;
|
||||
hasMcpResourcesSupport?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
|
|
@ -20,6 +21,7 @@
|
|||
let {
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVideoModality = false,
|
||||
hasMcpPromptsSupport = false,
|
||||
hasMcpResourcesSupport = false,
|
||||
hasVisionModality = false,
|
||||
|
|
@ -37,6 +39,7 @@
|
|||
<ChatFormActionAddSheet
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVideoModality}
|
||||
{hasVisionModality}
|
||||
{hasMcpPromptsSupport}
|
||||
{hasMcpResourcesSupport}
|
||||
|
|
@ -52,6 +55,7 @@
|
|||
<ChatFormActionAddDropdown
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVideoModality}
|
||||
{hasVisionModality}
|
||||
{hasMcpPromptsSupport}
|
||||
{hasMcpResourcesSupport}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
disabled?: boolean;
|
||||
forceForegroundText?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVideoModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
hasModelSelected?: boolean;
|
||||
isSelectedModelInCache?: boolean;
|
||||
|
|
@ -23,6 +24,7 @@
|
|||
disabled = false,
|
||||
forceForegroundText = false,
|
||||
hasAudioModality = $bindable(false),
|
||||
hasVideoModality = $bindable(false),
|
||||
hasVisionModality = $bindable(false),
|
||||
hasModelSelected = $bindable(false),
|
||||
isSelectedModelInCache = $bindable(true),
|
||||
|
|
@ -95,6 +97,10 @@
|
|||
hasAudioModality = activeModelId ? modelsStore.modelSupportsAudio(activeModelId) : false;
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
hasVideoModality = activeModelId ? modelsStore.modelSupportsVideo(activeModelId) : false;
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
void modelPropsVersion;
|
||||
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@
|
|||
});
|
||||
|
||||
let hasAudioModality = $state(false);
|
||||
let hasVideoModality = $state(false);
|
||||
let hasVisionModality = $state(false);
|
||||
let hasModelSelected = $state(false);
|
||||
let isSelectedModelInCache = $state(true);
|
||||
|
|
@ -94,6 +95,7 @@
|
|||
<ChatFormActionsAdd
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVideoModality}
|
||||
{hasVisionModality}
|
||||
{hasMcpPromptsSupport}
|
||||
{hasMcpResourcesSupport}
|
||||
|
|
@ -111,6 +113,7 @@
|
|||
{disabled}
|
||||
bind:this={selectorModelRef}
|
||||
bind:hasAudioModality
|
||||
bind:hasVideoModality
|
||||
bind:hasVisionModality
|
||||
bind:hasModelSelected
|
||||
bind:isSelectedModelInCache
|
||||
|
|
|
|||
|
|
@ -379,9 +379,6 @@
|
|||
border-radius: 1rem;
|
||||
background: hsl(var(--muted) / 0.3);
|
||||
color: var(--foreground);
|
||||
font-family:
|
||||
ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
|
||||
'Liberation Mono', Menlo, monospace;
|
||||
font-size: 0.875rem;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
|
|
|
|||
|
|
@ -144,6 +144,16 @@
|
|||
return false;
|
||||
});
|
||||
|
||||
let hasVideoModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
||||
return modelsStore.modelSupportsVideo(activeModelId);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
let hasVisionModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
|
@ -284,7 +294,11 @@
|
|||
}
|
||||
|
||||
// Use model-specific capabilities for file validation
|
||||
const capabilities = { hasVision: hasVisionModality, hasAudio: hasAudioModality };
|
||||
const capabilities = {
|
||||
hasVision: hasVisionModality,
|
||||
hasAudio: hasAudioModality,
|
||||
hasVideo: hasVideoModality
|
||||
};
|
||||
const { supportedFiles, unsupportedFiles, modalityReasons } = filterFilesByModalities(
|
||||
generallySupported,
|
||||
capabilities
|
||||
|
|
@ -297,6 +311,7 @@
|
|||
|
||||
if (hasVisionModality) supportedTypes.push('images');
|
||||
if (hasAudioModality) supportedTypes.push('audio files');
|
||||
if (hasVideoModality) supportedTypes.push('video files');
|
||||
|
||||
fileErrorData = {
|
||||
generallyUnsupported,
|
||||
|
|
|
|||
|
|
@ -742,9 +742,6 @@
|
|||
padding: 0.125rem 0.375rem;
|
||||
border-radius: 0.375rem;
|
||||
font-size: 0.875rem;
|
||||
font-family:
|
||||
ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
|
||||
'Liberation Mono', Menlo, monospace;
|
||||
}
|
||||
|
||||
div :global(pre) {
|
||||
|
|
|
|||
|
|
@ -80,12 +80,6 @@
|
|||
</div>
|
||||
|
||||
<style>
|
||||
.code-preview-wrapper {
|
||||
font-family:
|
||||
ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
|
||||
'Liberation Mono', Menlo, monospace;
|
||||
}
|
||||
|
||||
.code-preview-wrapper pre {
|
||||
background: transparent;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,6 +52,15 @@ export const ATTACHMENT_FILE_ITEMS: AttachmentMenuItem[] = [
|
|||
disabledTooltip: 'Audio files processing requires an audio model',
|
||||
action: AttachmentAction.FILE_UPLOAD
|
||||
},
|
||||
{
|
||||
id: AttachmentMenuItemId.VIDEO,
|
||||
label: 'Video Files',
|
||||
icon: FILE_TYPE_ICONS.video,
|
||||
class: 'video-button',
|
||||
enabledWhen: AttachmentItemEnabledWhen.HAS_VIDEO_MODALITY,
|
||||
disabledTooltip: 'Video files processing requires a video model',
|
||||
action: AttachmentAction.FILE_UPLOAD
|
||||
},
|
||||
{
|
||||
id: AttachmentMenuItemId.TEXT,
|
||||
label: 'Text Files',
|
||||
|
|
|
|||
|
|
@ -8,13 +8,15 @@ import {
|
|||
FileText as FileTextIcon,
|
||||
Image as ImageIcon,
|
||||
Eye as VisionIcon,
|
||||
Mic as AudioIcon
|
||||
Mic as AudioIcon,
|
||||
Video as VideoIcon
|
||||
} from '@lucide/svelte';
|
||||
import { FileTypeCategory, ModelModality } from '$lib/enums';
|
||||
|
||||
export const FILE_TYPE_ICONS = {
|
||||
[FileTypeCategory.IMAGE]: ImageIcon,
|
||||
[FileTypeCategory.AUDIO]: AudioIcon,
|
||||
[FileTypeCategory.VIDEO]: VideoIcon,
|
||||
[FileTypeCategory.TEXT]: FileTextIcon,
|
||||
[FileTypeCategory.PDF]: FileIcon
|
||||
} as const;
|
||||
|
|
@ -23,10 +25,12 @@ export const DEFAULT_FILE_ICON = FileIcon;
|
|||
|
||||
export const MODALITY_ICONS = {
|
||||
[ModelModality.VISION]: VisionIcon,
|
||||
[ModelModality.AUDIO]: AudioIcon
|
||||
[ModelModality.AUDIO]: AudioIcon,
|
||||
[ModelModality.VIDEO]: VideoIcon
|
||||
} as const;
|
||||
|
||||
export const MODALITY_LABELS = {
|
||||
[ModelModality.VISION]: 'Vision',
|
||||
[ModelModality.AUDIO]: 'Audio'
|
||||
[ModelModality.AUDIO]: 'Audio',
|
||||
[ModelModality.VIDEO]: 'Video'
|
||||
} as const;
|
||||
|
|
|
|||
|
|
@ -13,10 +13,12 @@ import {
|
|||
FileTypePdf,
|
||||
FileTypeText,
|
||||
MimeTypeAudio,
|
||||
MimeTypeVideo,
|
||||
MimeTypeImage,
|
||||
MimeTypeApplication,
|
||||
MimeTypeText
|
||||
} from '$lib/enums';
|
||||
import { FileExtensionVideo, FileTypeVideo } from '$lib/enums/files';
|
||||
|
||||
// File type configuration using enums
|
||||
export const AUDIO_FILE_TYPES = {
|
||||
|
|
@ -30,6 +32,17 @@ export const AUDIO_FILE_TYPES = {
|
|||
}
|
||||
} as const;
|
||||
|
||||
export const VIDEO_FILE_TYPES = {
|
||||
[FileTypeVideo.MP4]: {
|
||||
extensions: [FileExtensionVideo.MP4],
|
||||
mimeTypes: [MimeTypeVideo.MP4]
|
||||
},
|
||||
[FileTypeVideo.OGG]: {
|
||||
extensions: [FileExtensionVideo.OGG],
|
||||
mimeTypes: [MimeTypeVideo.OGG]
|
||||
}
|
||||
} as const;
|
||||
|
||||
export const IMAGE_FILE_TYPES = {
|
||||
[FileTypeImage.JPEG]: {
|
||||
extensions: [FileExtensionImage.JPG, FileExtensionImage.JPEG],
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
export enum AttachmentType {
|
||||
AUDIO = 'AUDIO',
|
||||
IMAGE = 'IMAGE',
|
||||
VIDEO = 'VIDEO',
|
||||
MCP_PROMPT = 'MCP_PROMPT',
|
||||
MCP_RESOURCE = 'MCP_RESOURCE',
|
||||
PDF = 'PDF',
|
||||
|
|
@ -18,6 +19,7 @@ export enum AttachmentType {
|
|||
export enum AttachmentMenuItemId {
|
||||
IMAGES = 'images',
|
||||
AUDIO = 'audio',
|
||||
VIDEO = 'video',
|
||||
TEXT = 'text',
|
||||
PDF = 'pdf',
|
||||
SYSTEM_MESSAGE = 'system-message',
|
||||
|
|
@ -31,7 +33,8 @@ export enum AttachmentMenuItemId {
|
|||
export enum AttachmentItemEnabledWhen {
|
||||
ALWAYS = 'always',
|
||||
HAS_VISION_MODALITY = 'hasVisionModality',
|
||||
HAS_AUDIO_MODALITY = 'hasAudioModality'
|
||||
HAS_AUDIO_MODALITY = 'hasAudioModality',
|
||||
HAS_VIDEO_MODALITY = 'hasVideoModality'
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -39,7 +39,8 @@ export enum MessageType {
|
|||
export enum ContentPartType {
|
||||
TEXT = 'text',
|
||||
IMAGE_URL = 'image_url',
|
||||
INPUT_AUDIO = 'input_audio'
|
||||
INPUT_AUDIO = 'input_audio',
|
||||
INPUT_VIDEO = 'input_video'
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
export enum FileTypeCategory {
|
||||
IMAGE = 'image',
|
||||
AUDIO = 'audio',
|
||||
VIDEO = 'video',
|
||||
PDF = 'pdf',
|
||||
TEXT = 'text'
|
||||
}
|
||||
|
|
@ -33,6 +34,11 @@ export enum FileTypeAudio {
|
|||
WEBM = 'webm'
|
||||
}
|
||||
|
||||
export enum FileTypeVideo {
|
||||
MP4 = 'mp4',
|
||||
OGG = 'ogg'
|
||||
}
|
||||
|
||||
export enum FileTypePdf {
|
||||
PDF = 'pdf'
|
||||
}
|
||||
|
|
@ -92,6 +98,11 @@ export enum FileExtensionAudio {
|
|||
WAV = '.wav'
|
||||
}
|
||||
|
||||
export enum FileExtensionVideo {
|
||||
MP4 = '.mp4',
|
||||
OGG = '.ogg'
|
||||
}
|
||||
|
||||
export enum FileExtensionPdf {
|
||||
PDF = '.pdf'
|
||||
}
|
||||
|
|
@ -176,6 +187,11 @@ export enum MimeTypeAudio {
|
|||
WEBM_OPUS = 'audio/webm;codecs=opus'
|
||||
}
|
||||
|
||||
export enum MimeTypeVideo {
|
||||
MP4 = 'video/mp4',
|
||||
OGG = 'video/ogg'
|
||||
}
|
||||
|
||||
export enum MimeTypeImage {
|
||||
JPEG = 'image/jpeg',
|
||||
JPG = 'image/jpg',
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ export {
|
|||
UriPattern,
|
||||
MimeTypeApplication,
|
||||
MimeTypeAudio,
|
||||
MimeTypeVideo,
|
||||
MimeTypeImage,
|
||||
MimeTypeText,
|
||||
SpecialFileType
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
export enum ModelModality {
|
||||
TEXT = 'TEXT',
|
||||
AUDIO = 'AUDIO',
|
||||
VISION = 'VISION'
|
||||
VISION = 'VISION',
|
||||
VIDEO = 'VIDEO'
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import { AttachmentAction } from '$lib/enums';
|
|||
export interface AttachmentModalityFlags {
|
||||
hasVisionModality: boolean;
|
||||
hasAudioModality: boolean;
|
||||
hasVideoModality: boolean;
|
||||
hasMcpPromptsSupport: boolean;
|
||||
hasMcpResourcesSupport: boolean;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -888,6 +888,25 @@ export class ChatService {
|
|||
});
|
||||
}
|
||||
|
||||
const videoFiles = message.extra.filter(
|
||||
(extra: DatabaseMessageExtra): extra is DatabaseMessageExtraVideoFile =>
|
||||
extra.type === AttachmentType.VIDEO
|
||||
);
|
||||
|
||||
for (const video of videoFiles) {
|
||||
contentParts.push({
|
||||
type: ContentPartType.INPUT_VIDEO,
|
||||
input_video: {
|
||||
data: video.base64Data,
|
||||
format: video.mimeType.includes('mp4')
|
||||
? 'mp4'
|
||||
: video.mimeType.includes('ogg')
|
||||
? 'ogg'
|
||||
: 'auto'
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const pdfFiles = message.extra.filter(
|
||||
(extra: DatabaseMessageExtra): extra is DatabaseMessageExtraPdfFile =>
|
||||
extra.type === AttachmentType.PDF
|
||||
|
|
|
|||
|
|
@ -148,7 +148,8 @@ class ModelsStore {
|
|||
if (props?.modalities) {
|
||||
return {
|
||||
vision: props.modalities.vision ?? false,
|
||||
audio: props.modalities.audio ?? false
|
||||
audio: props.modalities.audio ?? false,
|
||||
video: props.modalities.video ?? false
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -169,6 +170,13 @@ class ModelsStore {
|
|||
return this.getModelModalities(modelId)?.audio ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model supports video modality
|
||||
*/
|
||||
modelSupportsVideo(modelId: string): boolean {
|
||||
return this.getModelModalities(modelId)?.video ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model modalities as an array of ModelModality enum values
|
||||
*/
|
||||
|
|
@ -180,6 +188,7 @@ class ModelsStore {
|
|||
|
||||
if (modalities.vision) result.push(ModelModality.VISION);
|
||||
if (modalities.audio) result.push(ModelModality.AUDIO);
|
||||
if (modalities.video) result.push(ModelModality.VIDEO);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
@ -316,7 +325,8 @@ class ModelsStore {
|
|||
if (serverStore.isModelMode && this.models.length > 0 && serverProps?.modalities) {
|
||||
const modalities: ModelModalities = {
|
||||
vision: serverProps.modalities.vision ?? false,
|
||||
audio: serverProps.modalities.audio ?? false
|
||||
audio: serverProps.modalities.audio ?? false,
|
||||
video: serverProps.modalities.video ?? false
|
||||
};
|
||||
this.modelPropsCache.set(this.models[0].model, serverProps);
|
||||
this.models = this.models.map((model, index) =>
|
||||
|
|
@ -410,7 +420,8 @@ class ModelsStore {
|
|||
|
||||
const modalities: ModelModalities = {
|
||||
vision: props.modalities.vision ?? false,
|
||||
audio: props.modalities.audio ?? false
|
||||
audio: props.modalities.audio ?? false,
|
||||
video: props.modalities.video ?? false
|
||||
};
|
||||
|
||||
return { ...model, modalities };
|
||||
|
|
@ -529,7 +540,8 @@ class ModelsStore {
|
|||
|
||||
const modalities: ModelModalities = {
|
||||
vision: props.modalities.vision ?? false,
|
||||
audio: props.modalities.audio ?? false
|
||||
audio: props.modalities.audio ?? false,
|
||||
video: props.modalities.video ?? false
|
||||
};
|
||||
|
||||
this.models = this.models.map((model) =>
|
||||
|
|
|
|||
5
tools/ui/src/lib/types/api.d.ts
vendored
5
tools/ui/src/lib/types/api.d.ts
vendored
|
|
@ -22,6 +22,10 @@ export interface ApiChatMessageContentPart {
|
|||
data: string;
|
||||
format: 'wav' | 'mp3';
|
||||
};
|
||||
input_video?: {
|
||||
data: string;
|
||||
format: 'mp4' | 'ogg' | 'auto';
|
||||
};
|
||||
}
|
||||
|
||||
export interface ApiContextSizeError {
|
||||
|
|
@ -190,6 +194,7 @@ export interface ApiLlamaCppServerProps {
|
|||
modalities: {
|
||||
vision: boolean;
|
||||
audio: boolean;
|
||||
video: boolean;
|
||||
};
|
||||
chat_template: string;
|
||||
bos_token: string;
|
||||
|
|
|
|||
7
tools/ui/src/lib/types/common.d.ts
vendored
7
tools/ui/src/lib/types/common.d.ts
vendored
|
|
@ -64,4 +64,9 @@ export interface ParsedClipboardContent {
|
|||
mcpPromptAttachments: ClipboardMcpPromptAttachment[];
|
||||
}
|
||||
|
||||
export type MimeTypeUnion = MimeTypeAudio | MimeTypeImage | MimeTypeApplication | MimeTypeText;
|
||||
export type MimeTypeUnion =
|
||||
| MimeTypeAudio
|
||||
| MimeTypeVideo
|
||||
| MimeTypeImage
|
||||
| MimeTypeApplication
|
||||
| MimeTypeText;
|
||||
|
|
|
|||
9
tools/ui/src/lib/types/database.d.ts
vendored
9
tools/ui/src/lib/types/database.d.ts
vendored
|
|
@ -23,6 +23,14 @@ export interface DatabaseMessageExtraAudioFile {
|
|||
mimeType: string;
|
||||
}
|
||||
|
||||
export interface DatabaseMessageExtraVideoFile {
|
||||
type: AttachmentType.VIDEO;
|
||||
name: string;
|
||||
size?: number;
|
||||
base64Data: string;
|
||||
mimeType: string;
|
||||
}
|
||||
|
||||
export interface DatabaseMessageExtraImageFile {
|
||||
type: AttachmentType.IMAGE;
|
||||
name: string;
|
||||
|
|
@ -82,6 +90,7 @@ export type DatabaseMessageExtra =
|
|||
| DatabaseMessageExtraImageFile
|
||||
| DatabaseMessageExtraTextFile
|
||||
| DatabaseMessageExtraAudioFile
|
||||
| DatabaseMessageExtraVideoFile
|
||||
| DatabaseMessageExtraPdfFile
|
||||
| DatabaseMessageExtraMcpPrompt
|
||||
| DatabaseMessageExtraMcpResource
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ export type {
|
|||
McpServerOverride,
|
||||
DatabaseConversation,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraVideoFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraLegacyContext,
|
||||
DatabaseMessageExtraMcpPrompt,
|
||||
|
|
|
|||
2
tools/ui/src/lib/types/models.d.ts
vendored
2
tools/ui/src/lib/types/models.d.ts
vendored
|
|
@ -3,6 +3,7 @@ import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api';
|
|||
export interface ModelModalities {
|
||||
vision: boolean;
|
||||
audio: boolean;
|
||||
video: boolean;
|
||||
}
|
||||
|
||||
export interface ModelOption {
|
||||
|
|
@ -35,4 +36,5 @@ export interface ParsedModelId {
|
|||
export interface ModalityCapabilities {
|
||||
hasVision: boolean;
|
||||
hasAudio: boolean;
|
||||
hasVideo: boolean;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,3 +103,24 @@ export function isAudioFile(
|
|||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines if an attachment or uploaded file is a video file
|
||||
* @param attachment - Optional database attachment
|
||||
* @param uploadedFile - Optional uploaded file (takes precedence when both are given)
|
||||
* @returns true if the file is a video file
|
||||
*/
|
||||
export function isVideoFile(
|
||||
attachment?: DatabaseMessageExtra,
|
||||
uploadedFile?: ChatUploadedFile
|
||||
): boolean {
|
||||
if (uploadedFile) {
|
||||
return getUploadedFileCategory(uploadedFile) === FileTypeCategory.VIDEO;
|
||||
}
|
||||
|
||||
if (attachment) {
|
||||
return attachment.type === AttachmentType.VIDEO;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,6 +89,21 @@ export async function parseFilesToMessageExtras(
|
|||
} catch (error) {
|
||||
console.error(`Failed to process audio file ${file.name}:`, error);
|
||||
}
|
||||
} else if (getFileTypeCategory(file.type) === FileTypeCategory.VIDEO) {
|
||||
// Process video files (MP4, OGG)
|
||||
try {
|
||||
const base64Data = await readFileAsBase64(file.file);
|
||||
|
||||
extras.push({
|
||||
type: AttachmentType.VIDEO,
|
||||
name: file.name,
|
||||
size: file.size,
|
||||
base64Data: base64Data,
|
||||
mimeType: file.type
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(`Failed to process video file ${file.name}:`, error);
|
||||
}
|
||||
} else if (getFileTypeCategory(file.type) === FileTypeCategory.PDF) {
|
||||
try {
|
||||
// Always get base64 data for preview functionality
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import {
|
||||
AUDIO_FILE_TYPES,
|
||||
VIDEO_FILE_TYPES,
|
||||
IMAGE_FILE_TYPES,
|
||||
PDF_FILE_TYPES,
|
||||
TEXT_FILE_TYPES
|
||||
|
|
@ -12,6 +13,7 @@ import {
|
|||
FileTypeCategory,
|
||||
MimeTypeApplication,
|
||||
MimeTypeAudio,
|
||||
MimeTypeVideo,
|
||||
MimeTypeImage,
|
||||
MimeTypeText
|
||||
} from '$lib/enums';
|
||||
|
|
@ -35,6 +37,11 @@ export function getFileTypeCategory(mimeType: string): FileTypeCategory | null {
|
|||
case MimeTypeAudio.WEBM_OPUS:
|
||||
return FileTypeCategory.AUDIO;
|
||||
|
||||
// Video
|
||||
case MimeTypeVideo.MP4:
|
||||
case MimeTypeVideo.OGG:
|
||||
return FileTypeCategory.VIDEO;
|
||||
|
||||
// PDF
|
||||
case MimeTypeApplication.PDF:
|
||||
return FileTypeCategory.PDF;
|
||||
|
|
@ -179,6 +186,12 @@ export function getFileTypeByExtension(filename: string): string | null {
|
|||
}
|
||||
}
|
||||
|
||||
for (const [key, type] of Object.entries(VIDEO_FILE_TYPES)) {
|
||||
if ((type.extensions as readonly string[]).includes(extension)) {
|
||||
return `${FileTypeCategory.VIDEO}:${key}`;
|
||||
}
|
||||
}
|
||||
|
||||
for (const [key, type] of Object.entries(PDF_FILE_TYPES)) {
|
||||
if ((type.extensions as readonly string[]).includes(extension)) {
|
||||
return `${FileTypeCategory.PDF}:${key}`;
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ export { validateApiKey } from './api-key-validation';
|
|||
|
||||
// Attachment utilities
|
||||
export { getAttachmentDisplayItems, isMcpPrompt, isMcpResource } from './attachment-display';
|
||||
export { isTextFile, isImageFile, isPdfFile, isAudioFile } from './attachment-type';
|
||||
export { isTextFile, isImageFile, isPdfFile, isAudioFile, isVideoFile } from './attachment-type';
|
||||
|
||||
// Textarea utilities
|
||||
export { default as autoResizeTextarea } from './autoresize-textarea';
|
||||
|
|
|
|||
|
|
@ -45,6 +45,10 @@ export function isFileTypeSupportedByModel(
|
|||
// Audio files require audio support
|
||||
return capabilities.hasAudio;
|
||||
|
||||
case FileTypeCategory.VIDEO:
|
||||
// Video files require video support
|
||||
return capabilities.hasVideo;
|
||||
|
||||
default:
|
||||
// Unknown categories - be conservative and allow
|
||||
return true;
|
||||
|
|
@ -69,7 +73,7 @@ export function filterFilesByModalities(
|
|||
const unsupportedFiles: File[] = [];
|
||||
const modalityReasons: Record<string, string> = {};
|
||||
|
||||
const { hasVision, hasAudio } = capabilities;
|
||||
const { hasVision, hasAudio, hasVideo } = capabilities;
|
||||
|
||||
for (const file of files) {
|
||||
const category = getFileTypeCategory(file.type);
|
||||
|
|
@ -91,6 +95,13 @@ export function filterFilesByModalities(
|
|||
}
|
||||
break;
|
||||
|
||||
case FileTypeCategory.VIDEO:
|
||||
if (!hasVideo) {
|
||||
isSupported = false;
|
||||
reason = 'Video files require a video-capable model';
|
||||
}
|
||||
break;
|
||||
|
||||
case FileTypeCategory.TEXT:
|
||||
case FileTypeCategory.PDF:
|
||||
// Always supported
|
||||
|
|
@ -127,7 +138,7 @@ export function generateModalityErrorMessage(
|
|||
): string {
|
||||
if (unsupportedFiles.length === 0) return '';
|
||||
|
||||
const { hasVision, hasAudio } = capabilities;
|
||||
const { hasVision, hasAudio, hasVideo } = capabilities;
|
||||
|
||||
let message = '';
|
||||
|
||||
|
|
@ -144,6 +155,7 @@ export function generateModalityErrorMessage(
|
|||
const supportedTypes: string[] = ['text files', 'PDFs'];
|
||||
if (hasVision) supportedTypes.push('images');
|
||||
if (hasAudio) supportedTypes.push('audio files');
|
||||
if (hasVideo) supportedTypes.push('video files');
|
||||
|
||||
message += ` This model supports: ${supportedTypes.join(', ')}.`;
|
||||
|
||||
|
|
|
|||
|
|
@ -117,6 +117,10 @@ export async function processFilesToChatUploaded(
|
|||
// Generate preview URL for audio files
|
||||
const preview = await readFileAsDataURL(file);
|
||||
results.push({ ...base, preview });
|
||||
} else if (getFileTypeCategory(file.type) === FileTypeCategory.VIDEO) {
|
||||
// Generate preview URL for video files
|
||||
const preview = await readFileAsDataURL(file);
|
||||
results.push({ ...base, preview });
|
||||
} else {
|
||||
// Fallback: treat unknown files as text
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ export function mockServerProps(props: Partial<ApiLlamaCppServerProps>): void {
|
|||
model_path: props.model_path || 'test-model',
|
||||
modalities: {
|
||||
vision: props.modalities?.vision ?? false,
|
||||
audio: props.modalities?.audio ?? false
|
||||
audio: props.modalities?.audio ?? false,
|
||||
video: props.modalities?.video ?? false
|
||||
},
|
||||
...props
|
||||
} as ApiLlamaCppServerProps;
|
||||
|
|
@ -26,11 +27,14 @@ export function mockServerProps(props: Partial<ApiLlamaCppServerProps>): void {
|
|||
// Also mock modelsStore methods for modality checking
|
||||
const vision = props.modalities?.vision ?? false;
|
||||
const audio = props.modalities?.audio ?? false;
|
||||
const video = props.modalities?.video ?? false;
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(modelsStore as any).modelSupportsVision = () => vision;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(modelsStore as any).modelSupportsAudio = () => audio;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(modelsStore as any).modelSupportsVideo = () => video;
|
||||
|
||||
// Mock models list with a test model so activeModelId can be resolved
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
|
|
@ -55,7 +59,8 @@ export function resetServerStore(): void {
|
|||
model_path: '',
|
||||
modalities: {
|
||||
vision: false,
|
||||
audio: false
|
||||
audio: false,
|
||||
video: false
|
||||
}
|
||||
} as ApiLlamaCppServerProps;
|
||||
(serverStore as unknown as { error: string }).error = '';
|
||||
|
|
@ -76,6 +81,6 @@ export const mockConfigs = {
|
|||
modalities: { vision: true, audio: true }
|
||||
},
|
||||
noModalities: {
|
||||
modalities: { vision: false, audio: false }
|
||||
modalities: { vision: false, audio: false, video: false }
|
||||
}
|
||||
} as const;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue