Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	common/CMakeLists.txt
#	docs/backend/SYCL.md
#	ggml/CMakeLists.txt
#	ggml/src/ggml-sycl/CMakeLists.txt
#	ggml/src/ggml-sycl/binbcast.cpp
#	ggml/src/ggml-sycl/convert.cpp
#	ggml/src/ggml-sycl/dequantize.hpp
#	ggml/src/ggml-sycl/dmmv.cpp
#	ggml/src/ggml-sycl/gemm.hpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-sycl/mmvq.cpp
#	ggml/src/ggml-sycl/vecdotq.hpp
#	ggml/src/ggml-vulkan/CMakeLists.txt
#	ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
#	ggml/src/gguf.cpp
#	scripts/compare-llama-bench.py
#	tests/CMakeLists.txt
#	tests/test-chat.cpp
#	tools/llama-bench/llama-bench.cpp
#	tools/server/README.md
Commit e5d26a2356 by Concedo, 2025-05-16 15:30:31 +08:00
47 changed files with 2671 additions and 504 deletions


@@ -6,6 +6,15 @@
 #include <optional>
 
+static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
+    auto time = std::chrono::system_clock::to_time_t(now);
+    auto local_time = *std::localtime(&time);
+    std::ostringstream ss;
+    ss << std::put_time(&local_time, format.c_str());
+    auto res = ss.str();
+    return res;
+}
+
 typedef minja::chat_template common_chat_template;
 
 struct common_chat_templates {
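As a quick illustration of what the new format_time helper produces — a standalone sketch, not part of the commit; the two format strings are the ones used further down in this diff:

// Minimal sketch of format_time's behavior.
#include <chrono>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
    auto time = std::chrono::system_clock::to_time_t(now);
    auto local_time = *std::localtime(&time);
    std::ostringstream ss;
    ss << std::put_time(&local_time, format.c_str());
    return ss.str();
}

int main() {
    auto now = std::chrono::system_clock::now();
    std::cout << format_time(now, "%d %b %Y") << "\n";              // e.g. "16 May 2025" (Llama 3.x date_string)
    std::cout << format_time(now, "%b %d %Y %H:%M:%S GMT") << "\n"; // e.g. "May 16 2025 15:30:31 GMT" (FireFunction datetime)
}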
@@ -24,6 +33,7 @@ struct templates_params {
     std::string grammar;
     bool add_generation_prompt = true;
     bool extract_reasoning = true;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -939,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     }
 }
 
-static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
-    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
-        std::vector<std::string> tool_rules;
+    if (!inputs.tools.is_null()) {
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
 
             auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
                 if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
                     // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
                     // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
                     expect_tool_parameters(name, parameters, {"query"});
                 } else if (name == "python" || name == "code_interpreter") {
                     // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
                     expect_tool_parameters(name, parameters, {"code"});
                 } else {
                     return false;
                 }
 
                 std::vector<std::string> kvs;
                 for (const auto & [key, value] : parameters.at("properties").items()) {
                     kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
                 }
 
                 tool_rules.push_back(
                     builder.add_rule(
                         name + "-call",
                         "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
                 builtin_tools.push_back(name);
 
                 return true;
             };
 
             foreach_function(inputs.tools, [&](const json & tool) {
                 const auto & function = tool.at("function");
                 std::string name = function.at("name");
                 auto parameters = function.at("parameters");
                 builder.resolve_refs(parameters);
 
                 // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
                 if (allow_python_tag_builtin_tools) {
                     handle_builtin_tool(name, parameters);
                 }
                 tool_rules.push_back(
                     builder.add_rule(
                         name + "-call",
                         "\"{\" space "
                         "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
                         " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
                         " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
                         "\"}\" space"));
             });
             // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
             data.grammar_triggers.push_back({
                 COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
                 "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
             });
             if (!builtin_tools.empty()) {
                 data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
                 data.preserved_tokens.push_back("<|python_tag|>");
             }
             // Allow a few empty lines on top of the usual constrained json schema space rule.
             builder.add_rule("root", string_join(tool_rules, " | "));
-    });
-    data.additional_stops.push_back("<|eom_id|>");
+        });
+        data.additional_stops.push_back("<|eom_id|>");
+        data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
+            ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
+            : COMMON_CHAT_FORMAT_LLAMA_3_X;
+    } else {
+        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    }
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+        {"date_string", format_time(inputs.now, "%d %b %Y")},
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
-    data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
-        ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
-        : COMMON_CHAT_FORMAT_LLAMA_3_X;
     return data;
 }
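A standalone sketch of what that COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START pattern accepts — the assumption here is only that the server anchors PATTERN_START triggers at the start of the generated text; the real trigger plumbing lives elsewhere in common/chat.cpp:

// Check the lazy-grammar trigger pattern against a few sample model outputs.
#include <iostream>
#include <regex>
#include <string>

int main() {
    const std::regex trigger("\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"");
    for (const std::string out : {
        "{\"name\": \"get_weather\", \"parameters\": {\"city\": \"Paris\"}}", // plain call -> triggers
        "{ \"type\": \"function\", \"name\": \"frobnicate\", ...",            // hallucinated name -> still triggers
        "The answer is 42.",                                                  // prose -> no trigger
    }) {
        std::smatch m;
        const bool hit = std::regex_search(out, m, trigger) && m.position(0) == 0;
        std::cout << (hit ? "trigger" : "no trigger") << ": " << out << "\n";
    }
}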
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
@@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
-        {"datetime", "Jan 29 2025 13:00:00 GMT"},
+        {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -1285,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
 static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_params data;
-    json tools = inputs.tools.is_null() ? inputs.tools : json::array();
-    std::string python_code_argument_name;
-    auto has_raw_python = false;
 
-    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
-    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
-        std::vector<std::string> tool_rules;
+    if (!inputs.tools.is_null()) {
+        std::string python_code_argument_name;
+        auto has_raw_python = false;
+
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
             foreach_function(inputs.tools, [&](const json & tool) {
                 const auto & function = tool.at("function");
                 const auto & parameters = function.at("parameters");
                 std::string name = function.at("name");
                 if (name == "python" || name == "ipython") {
                     if (!parameters.contains("type")) {
                         throw std::runtime_error("Missing type in python tool");
                     }
                     has_raw_python = true;
                     const auto & type = parameters.at("type");
                     if (type == "object") {
                         auto properties = parameters.at("properties");
                         for (auto it = properties.begin(); it != properties.end(); ++it) {
                             if (it.value().at("type") == "string") {
                                 if (!python_code_argument_name.empty()) {
                                     throw std::runtime_error("Multiple string arguments found in python tool");
                                 }
                                 python_code_argument_name = it.key();
                             }
                         }
                         if (python_code_argument_name.empty()) {
                             throw std::runtime_error("No string argument found in python tool");
                         }
                     } else if (type != "string") {
                         throw std::runtime_error("Invalid type in python tool: " + type.dump());
                     }
                 }
                 tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
             });
             if (has_raw_python) {
                 tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
                 data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
                 data.preserved_tokens.push_back("<|python_tag|>");
             }
             auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
             builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
             data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
-    });
+        });
+        data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
+    } else {
+        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    }
 
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     // TODO: if (has_raw_python)
-    data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
     return data;
 }
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) { static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
@@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
     params.extract_reasoning = inputs.extract_reasoning;
     params.tool_choice = inputs.tool_choice;
     params.grammar = inputs.grammar;
+    params.now = inputs.now;
     if (!inputs.json_schema.empty()) {
         params.json_schema = json::parse(inputs.json_schema);
     }
@@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_firefunction_v2(tmpl, params);
     }
 
-    // Plain handler (no tools)
-    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
-        return common_chat_params_init_without_tools(tmpl, params);
-    }
-
     // Functionary v3.1 (w/ tools)
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
         return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
     }
 
-    // Llama 3.1, 3.2, 3.3 (w/ tools)
+    // Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
         auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
-        return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
+        return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
     }
 
+    // Plain handler (no tools)
+    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+        return common_chat_params_init_without_tools(tmpl, params);
+    }
+
     // Mistral Nemo (w/ tools)


@@ -3,6 +3,7 @@
 #pragma once
 
 #include "common.h"
+#include <chrono>
 #include <string>
 #include <vector>
 
@@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
     bool parallel_tool_calls = false;
     bool extract_reasoning = true;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
 struct common_chat_params {
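One practical effect of the new field: callers can pin the clock so rendered prompts stay reproducible — a sketch under the assumption that callers construct the inputs struct directly, as the header above suggests:

#include <chrono>
#include "chat.h"

void example() {
    common_chat_templates_inputs inputs;
    // Freeze the timestamp so date_string/datetime in rendered templates are deterministic (e.g. in tests).
    inputs.now = std::chrono::system_clock::from_time_t(0); // 1970-01-01
}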


@@ -451,6 +451,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
     s = std::move(builder);
 }
 
+bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
+    return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
+}
+
+size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
+    if (!str.empty() && !stop.empty()) {
+        const char text_last_char = str.back();
+        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
+            if (stop[char_index] == text_last_char) {
+                const auto current_partial = stop.substr(0, char_index + 1);
+                if (string_ends_with(str, current_partial)) {
+                    return str.size() - char_index - 1;
+                }
+            }
+        }
+    }
+
+    return std::string::npos;
+}
+
 std::string regex_escape(const std::string & s) {
     static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
     return std::regex_replace(s, special_chars, "\\$0");
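A self-contained usage sketch for the new helper (definitions copied from the hunk above so it compiles standalone):

#include <iostream>
#include <string>
#include <string_view>

bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
    return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}

size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
    if (!str.empty() && !stop.empty()) {
        const char text_last_char = str.back();
        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
            if (stop[char_index] == text_last_char) {
                const auto current_partial = stop.substr(0, char_index + 1);
                if (string_ends_with(str, current_partial)) {
                    return str.size() - char_index - 1;
                }
            }
        }
    }
    return std::string::npos;
}

int main() {
    // The streamed text ends with "<|eo", a prefix of the stop word "<|eom_id|>":
    // the caller should hold those bytes back instead of sending them to the client.
    size_t pos = string_find_partial_stop("Hello<|eo", "<|eom_id|>");
    if (pos != std::string::npos) {
        std::cout << "possible stop word starting at offset " << pos << "\n"; // prints offset 5
    }
}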


@@ -6,6 +6,7 @@
 
 #include <set>
 #include <string>
+#include <string_view>
 #include <vector>
 #include <sstream>
 
@@ -499,10 +500,9 @@ static bool string_starts_with(const std::string & str,
     return str.rfind(prefix, 0) == 0;
 }
 
-static bool string_ends_with(const std::string & str,
-                             const std::string & suffix) {  // While we wait for C++20's std::string::ends_with...
-    return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
-}
+// While we wait for C++20's std::string::ends_with...
+bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
+size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
 
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);


@@ -13,10 +13,12 @@
 #include <chrono>
 #include <cstddef>
 #include <cstdio>
+#include <ctime>
 #include <exception>
 #include <iomanip>
 #include <memory>
 #include <sstream>
+#include <stdexcept>
 #include <string>
 #include <vector>
 
@@ -393,8 +395,8 @@ class chat_template {
         for (const auto & message_ : adjusted_messages) {
             auto message = message_;
-            if (!message.contains("role") || !message.contains("content")) {
-                throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+            if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
+                throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
             }
             std::string role = message.at("role");
 
@@ -415,7 +417,6 @@ class chat_template {
                     }
                 }
                 if (polyfill_tool_calls) {
-                    auto content = message.at("content");
                     auto tool_calls = json::array();
                     for (const auto & tool_call : message.at("tool_calls")) {
                         if (tool_call.at("type") != "function") {
@@ -434,8 +435,11 @@ class chat_template {
                     auto obj = json {
                         {"tool_calls", tool_calls},
                     };
-                    if (!content.is_null() && !content.empty()) {
-                        obj["content"] = content;
+                    if (message.contains("content")) {
+                        auto content = message.at("content");
+                        if (!content.is_null() && !content.empty()) {
+                            obj["content"] = content;
+                        }
                     }
                     message["content"] = obj.dump(2);
                     message.erase("tool_calls");


@@ -11,6 +11,7 @@
 #include <algorithm>
 #include <cctype>
 #include <cstddef>
+#include <cstdint>
 #include <cmath>
 #include <exception>
 #include <functional>
@@ -233,7 +234,7 @@ public:
             }
         } else if (is_object()) {
             if (!index.is_hashable())
-                throw std::runtime_error("Unashable type: " + index.dump());
+                throw std::runtime_error("Unhashable type: " + index.dump());
             auto it = object_->find(index.primitive_);
             if (it == object_->end())
                 throw std::runtime_error("Key not found: " + index.dump());
@@ -252,7 +253,7 @@ public:
             auto index = key.get<int>();
             return array_->at(index < 0 ? array_->size() + index : index);
         } else if (object_) {
-            if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+            if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
             auto it = object_->find(key.primitive_);
             if (it == object_->end()) return Value();
             return it->second;
@@ -261,7 +262,7 @@ public:
     }
     void set(const Value& key, const Value& value) {
         if (!object_) throw std::runtime_error("Value is not an object: " + dump());
-        if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+        if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
         (*object_)[key.primitive_] = value;
     }
     Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
@@ -398,7 +399,7 @@ public:
             }
             return false;
         } else if (object_) {
-            if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
+            if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
             return object_->find(value.primitive_) != object_->end();
         } else {
             throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
@@ -416,7 +417,7 @@ public:
         return const_cast<Value*>(this)->at(index);
     }
     Value& at(const Value & index) {
-        if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+        if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
         if (is_array()) return array_->at(index.get<int>());
         if (is_object()) return object_->at(index.primitive_);
         throw std::runtime_error("Value is not an array or object: " + dump());
@@ -676,8 +677,8 @@ public:
 class VariableExpr : public Expression {
     std::string name;
   public:
-    VariableExpr(const Location & location, const std::string& n)
-        : Expression(location), name(n) {}
+    VariableExpr(const Location & loc, const std::string& n)
+        : Expression(loc), name(n) {}
     std::string get_name() const { return name; }
     Value do_evaluate(const std::shared_ptr<Context> & context) const override {
         if (!context->contains(name)) {
@@ -1200,9 +1201,9 @@ public:
 class SliceExpr : public Expression {
   public:
-    std::shared_ptr<Expression> start, end;
-    SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
-        : Expression(loc), start(std::move(s)), end(std::move(e)) {}
+    std::shared_ptr<Expression> start, end, step;
+    SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
+        : Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
     Value do_evaluate(const std::shared_ptr<Context> &) const override {
         throw std::runtime_error("SliceExpr not implemented");
     }
@@ -1219,18 +1220,35 @@ public:
         if (!index) throw std::runtime_error("SubscriptExpr.index is null");
         auto target_value = base->evaluate(context);
         if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
-            auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
-            auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
+            auto len = target_value.size();
+            auto wrap = [len](int64_t i) -> int64_t {
+                if (i < 0) {
+                    return i + len;
+                }
+                return i;
+            };
+            int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
+            if (!step) {
+                throw std::runtime_error("slice step cannot be zero");
+            }
+            int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
+            int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
             if (target_value.is_string()) {
                 std::string s = target_value.get<std::string>();
-                if (start < 0) start = s.size() + start;
-                if (end < 0) end = s.size() + end;
-                return s.substr(start, end - start);
+
+                std::string result;
+                if (start < end && step == 1) {
+                    result = s.substr(start, end - start);
+                } else {
+                    for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
+                        result += s[i];
+                    }
+                }
+                return result;
             } else if (target_value.is_array()) {
-                if (start < 0) start = target_value.size() + start;
-                if (end < 0) end = target_value.size() + end;
                 auto result = Value::array();
-                for (auto i = start; i < end; ++i) {
+                for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
                     result.push_back(target_value.at(i));
                 }
                 return result;
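To sanity-check the indexing rules above, here is a standalone sketch that mirrors the new start/end/step defaults (negative indices wrap once by +len; a negative step defaults start to len-1 and end to -1) — it is not minja itself:

#include <cstdint>
#include <iostream>
#include <string>

static std::string slice(const std::string & s, const int64_t * start_p, const int64_t * end_p, int64_t step) {
    const int64_t len = (int64_t) s.size();
    auto wrap = [len](int64_t i) { return i < 0 ? i + len : i; };
    int64_t start = start_p ? wrap(*start_p) : (step < 0 ? len - 1 : 0);
    int64_t end   = end_p   ? wrap(*end_p)   : (step < 0 ? -1  : len);
    std::string result;
    for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
        result += s[i];
    }
    return result;
}

int main() {
    std::cout << slice("hello", nullptr, nullptr, -1) << "\n"; // "olleh"  ~ {{ "hello"[::-1] }}
    int64_t start = 1, end = 4;
    std::cout << slice("hello", &start, &end, 2) << "\n";      // "el"     ~ {{ "hello"[1:4:2] }}
}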
@@ -1305,6 +1323,8 @@ public:
             if (name == "iterable") return l.is_iterable();
             if (name == "sequence") return l.is_array();
             if (name == "defined") return !l.is_null();
+            if (name == "true") return l.to_bool();
+            if (name == "false") return !l.to_bool();
             throw std::runtime_error("Unknown type for 'is' operator: " + name);
         };
         auto value = eval();
@@ -1520,6 +1540,10 @@ public:
                     vargs.expectArgs("endswith method", {1, 1}, {0, 0});
                     auto suffix = vargs.args[0].get<std::string>();
                     return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
+                } else if (method->get_name() == "startswith") {
+                    vargs.expectArgs("startswith method", {1, 1}, {0, 0});
+                    auto prefix = vargs.args[0].get<std::string>();
+                    return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
                 } else if (method->get_name() == "title") {
                     vargs.expectArgs("title method", {0, 0}, {0, 0});
                     auto res = str;
@@ -2082,28 +2106,37 @@ private:
             while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
                 if (!consumeToken("[").empty()) {
                     std::shared_ptr<Expression> index;
+                    auto slice_loc = get_location();
+                    std::shared_ptr<Expression> start, end, step;
+                    bool has_first_colon = false, has_second_colon = false;
+
+                    if (!peekSymbols({ ":" })) {
+                        start = parseExpression();
+                    }
+
                     if (!consumeToken(":").empty()) {
-                        auto slice_end = parseExpression();
-                        index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
-                    } else {
-                        auto slice_start = parseExpression();
-                        if (!consumeToken(":").empty()) {
-                            consumeSpaces();
-                            if (peekSymbols({ "]" })) {
-                                index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
-                            } else {
-                                auto slice_end = parseExpression();
-                                index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
-                            }
-                        } else {
-                            index = std::move(slice_start);
+                        has_first_colon = true;
+                        if (!peekSymbols({ ":", "]" })) {
+                            end = parseExpression();
+                        }
+                        if (!consumeToken(":").empty()) {
+                            has_second_colon = true;
+                            if (!peekSymbols({ "]" })) {
+                                step = parseExpression();
+                            }
                         }
                     }
-                    if (!index) throw std::runtime_error("Empty index in subscript");
-                    if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
-                    value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
+
+                    if ((has_first_colon || has_second_colon) && (start || end || step)) {
+                        index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
+                    } else {
+                        index = std::move(start);
+                    }
+
+                    if (!index) throw std::runtime_error("Empty index in subscript");
+                    if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
+
+                    value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
                 } else if (!consumeToken(".").empty()) {
                     auto identifier = parseIdentifier();
                     if (!identifier) throw std::runtime_error("Expected identifier in subscript");
common/regex-partial.cpp (new file, 204 lines)

@@ -0,0 +1,204 @@
#include "regex-partial.h"
#include "common.h"
#include <functional>
#include <optional>
common_regex::common_regex(const std::string & pattern) :
pattern(pattern),
rx(pattern),
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
std::smatch match;
if (pos > input.size()) {
throw std::runtime_error("Position out of bounds");
}
auto start = input.begin() + pos;
auto found = as_match
? std::regex_match(start, input.end(), match, rx)
: std::regex_search(start, input.end(), match, rx);
if (found) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
for (size_t i = 0; i < match.size(); ++i) {
auto begin = pos + match.position(i);
res.groups.emplace_back(begin, begin + match.length(i));
}
return res;
}
std::match_results<std::string::const_reverse_iterator> srmatch;
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
auto group = srmatch[1].str();
if (group.length() != 0) {
auto it = srmatch[1].second.base();
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
if ((!as_match) || it == input.begin()) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
const size_t begin = std::distance(input.begin(), it);
const size_t end = input.size();
if (begin == std::string::npos || end == std::string::npos || begin > end) {
throw std::runtime_error("Invalid range");
}
res.groups.push_back({begin, end});
return res;
}
}
}
return {};
}
/*
  Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.

  Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
  to see if a string ends with a partial regex match, but it's not in std::regex yet.
  Instead, we'll transform the regex into a partial match regex operating as a full match on the reverse iterators of the input.

  - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
  - /a|b/ -> (a|b).*
  - /a*?/ -> error, could match ""
  - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
  - /.*?ab/ -> ((?:b)?a).* (merge .*)
  - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
  - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
  - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
  - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*

  The regex will match a reversed string fully, and the end of the first (and only) capturing group will indicate the reversed start of the original partial pattern
  (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored).
*/
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto it = pattern.begin();
const auto end = pattern.end();
std::function<std::string()> process = [&]() {
std::vector<std::vector<std::string>> alternatives(1);
std::vector<std::string> * sequence = &alternatives.back();
while (it != end) {
if (*it == '[') {
auto start = it;
++it;
while (it != end) {
if ((*it == '\\') && (++it != end)) {
++it;
} else if ((it != end) && (*it == ']')) {
break;
} else {
++it;
}
}
if (it == end) {
throw std::runtime_error("Unmatched '[' in pattern");
}
++it;
sequence->push_back(std::string(start, it));
} else if (*it == '*' || *it == '?' || *it == '+') {
if (sequence->empty()) {
throw std::runtime_error("Quantifier without preceding element");
}
sequence->back() += *it;
auto is_star = *it == '*';
++it;
if (is_star) {
if (*it == '?') {
++it;
}
}
} else if (*it == '{') {
if (sequence->empty()) {
throw std::runtime_error("Repetition without preceding element");
}
++it;
auto start = it;
while (it != end && *it != '}') {
++it;
}
if (it == end) {
throw std::runtime_error("Unmatched '{' in pattern");
}
auto parts = string_split(std::string(start, it), ",");
++it;
if (parts.size() > 2) {
throw std::runtime_error("Invalid repetition range in pattern");
}
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
if (s.empty()) {
return def;
}
return std::stoi(s);
};
auto min = parseOptInt(parts[0], 0);
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
if (min && max && *max < *min) {
throw std::runtime_error("Invalid repetition range in pattern");
}
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
auto part = sequence->back();
sequence->pop_back();
for (int i = 0; i < *min; i++) {
sequence->push_back(part);
}
if (max) {
for (int i = *min; i < *max; i++) {
sequence->push_back(part + "?");
}
} else {
sequence->push_back(part + "*");
}
} else if (*it == '(') {
++it;
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
it += 2;
}
auto sub = process();
if (*it != ')') {
throw std::runtime_error("Unmatched '(' in pattern");
}
++it;
auto & part = sequence->emplace_back("(?:");
part += sub;
part += ")";
} else if (*it == ')') {
break;
} else if (*it == '|') {
++it;
alternatives.emplace_back();
sequence = &alternatives.back();
} else if (*it == '\\' && (++it != end)) {
auto str = std::string("\\") + *it;
sequence->push_back(str);
++it;
} else if (it != end) {
sequence->push_back(std::string(1, *it));
++it;
}
}
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
// We'll do the outermost capturing group and final .* in the enclosing function.
std::vector<std::string> res_alts;
for (const auto & parts : alternatives) {
auto & res = res_alts.emplace_back();
for (size_t i = 0; i < parts.size() - 1; i++) {
res += "(?:";
}
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
res += *it;
if (it != parts.rend() - 1) {
res += ")?";
}
}
}
return string_join(res_alts, "|");
};
auto res = process();
if (it != end) {
throw std::runtime_error("Unmatched '(' in pattern");
}
return "(" + res + ")[\\s\\S]*";
}
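A quick check of the transformation documented above — the expected output string for /abcd/ comes from the comment in this file; the sketch links against this translation unit rather than re-implementing it:

#include <iostream>
#include <regex>
#include <string>

std::string regex_to_reversed_partial_regex(const std::string & pattern); // defined above

int main() {
    auto rev = regex_to_reversed_partial_regex("abcd");
    std::cout << rev << "\n"; // ((?:(?:(?:d)?c)?b)?a)[\s\S]*

    // "xyzab" ends with "ab", a prefix of "abcd": the reversed input matches fully.
    std::string input = "xyzab";
    std::string reversed(input.rbegin(), input.rend());
    std::cout << std::boolalpha
              << std::regex_match(reversed, std::regex(rev)) << "\n"; // true
}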

common/regex-partial.h (new file, 56 lines)

@@ -0,0 +1,56 @@
#pragma once

#include <regex>
#include <string>

enum common_regex_match_type {
    COMMON_REGEX_MATCH_TYPE_NONE,
    COMMON_REGEX_MATCH_TYPE_PARTIAL,
    COMMON_REGEX_MATCH_TYPE_FULL,
};

struct common_string_range {
    size_t begin;
    size_t end;
    common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
        if (begin > end) {
            throw std::runtime_error("Invalid range");
        }
    }
    // prevent default ctor
    common_string_range() = delete;
    bool empty() const {
        return begin == end;
    }
    bool operator==(const common_string_range & other) const {
        return begin == other.begin && end == other.end;
    }
};

struct common_regex_match {
    common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
    std::vector<common_string_range> groups;

    bool operator==(const common_regex_match & other) const {
        return type == other.type && groups == other.groups;
    }
    bool operator!=(const common_regex_match & other) const {
        return !(*this == other);
    }
};

class common_regex {
    std::string pattern;
    std::regex rx;
    std::regex rx_reversed_partial;

  public:
    explicit common_regex(const std::string & pattern);

    common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;

    const std::string & str() const { return pattern; }
};

// For testing only (pretty print of failures).
std::string regex_to_reversed_partial_regex(const std::string & pattern);
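Usage sketch for the API above (links against common/regex-partial.cpp). Group 0 spans the whole full match; a partial match reports the span from the possible match start to the end of the input:

#include <iostream>
#include "regex-partial.h"

int main() {
    common_regex re("<tool_call>");

    auto full = re.search("pre <tool_call> post", 0);
    if (full.type == COMMON_REGEX_MATCH_TYPE_FULL) {
        std::cout << "full match at [" << full.groups[0].begin << ", " << full.groups[0].end << ")\n"; // [4, 15)
    }

    // The input ends with "<tool_", a prefix of the pattern: a streaming caller
    // should wait for more tokens before deciding.
    auto partial = re.search("some text <tool_", 0);
    if (partial.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
        std::cout << "partial match from offset " << partial.groups[0].begin << "\n"; // 10
    }
}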


@@ -2069,6 +2069,9 @@ class Llama4Model(LlamaModel):
         self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        if name.startswith("language_model."):
+            name = name.replace("language_model.", "")
+
         # split the gate_up into gate and up
         if "gate_up_proj" in name:
             name_up = name.replace("gate_up_proj", "up_proj.weight")


@@ -31,7 +31,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
 
 ## Pre-quantized models
 
-These are ready-to-use models, most of them come with `Q4_K_M` quantization by default.
+These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/ggml-org
 
 Replace `(tool_name)` with the name of the binary you want to use. For example, `llama-mtmd-cli` or `llama-server`


@@ -8520,7 +8520,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+#ifdef __ARM_FEATURE_MATMUL_INT8
+    assert((nrc == 2) || (nrc == 1));
+#else
     assert(nrc == 1);
+#endif
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
@@ -8531,6 +8535,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     const int nb = n / QK_K;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q6_K * GGML_RESTRICT x0 = x;
+        const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
+        const block_q8_K * GGML_RESTRICT y0 = y;
+        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+        float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+            const uint8_t * GGML_RESTRICT ql0 = x0->ql;
+            const uint8_t * GGML_RESTRICT ql1 = x1->ql;
+            const uint8_t * GGML_RESTRICT qh0 = x0->qh;
+            const uint8_t * GGML_RESTRICT qh1 = x1->qh;
+            const int8_t  * GGML_RESTRICT qy0 = y0->qs;
+            const int8_t  * GGML_RESTRICT qy1 = y1->qs;
+
+            const uint8x16_t mone = vdupq_n_u8(0x30);
+            const uint8x16_t  m4b = vdupq_n_u8(0x0f);
+
+            int32x4_t visum = vdupq_n_s32(0);
+
+            // process 8 blocks per iteration, totally 16 blocks
+            for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
+                int8x16_t vx0[8], vx1[8];
+
+                // de-quantize vx0[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // de-quantize vx1[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // process 16 elements (one block with same scale) per iteration
+                // - vx = concat(ql, qh) - 32
+                // - r1,r2,r3,r4 = smmla(vx, vy)
+                for (int k = 0; k < 8; ++k) {
+                    const int blk = j * 8 + k;
+
+                    const int8x16_t vy0 = vld1q_s8(qy0);
+                    const int8x16_t vy1 = vld1q_s8(qy1);
+                    qy0 += 16;
+                    qy1 += 16;
+
+                    const int32x4_t block_scale = {
+                        x0->scales[blk],
+                        x0->scales[blk],
+                        x1->scales[blk],
+                        x1->scales[blk],
+                    };
+
+                    // calculate four results at once with outer product
+                    const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    int32x4_t vr = vdupq_n_s32(0);
+                    vr = vmmlaq_s32(vr, vx_l, vy_l);
+                    vr = vmmlaq_s32(vr, vx_h, vy_h);
+
+                    // apply block scale, will NOT overflow
+                    // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
+                    visum = vmlaq_s32(visum, vr, block_scale);
+                }
+            }
+
+            // adjust bias, apply superblock scale
+            {
+                int32_t bias[4];
+#ifdef __ARM_FEATURE_SVE
+                const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+                const svbool_t pg8_8  = svptrue_pat_b8(SV_VL8);
+                const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
+                const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
+                const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
+                const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
+                const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
+                const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
+                const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
+                const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
+                const svint64_t zero = svdup_n_s64(0);
+
+                bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
+                bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
+                bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
+                bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
+#else
+                // NEON doesn't support int16 dot product, fallback to separated mul and add
+                const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
+                const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
+
+                int8x16_t scales_s8 = vld1q_s8(x0->scales);
+                const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+                scales_s8 = vld1q_s8(x1->scales);
+                const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+
+                int32x4_t prod;
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[0] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[1] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[2] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[3] = vaddvq_s32(prod);
+#endif
+                const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
+
+                const float32x4_t superblock_scale = {
+                    GGML_FP16_TO_FP32(x0->d) * y0->d,
+                    GGML_FP16_TO_FP32(x0->d) * y1->d,
+                    GGML_FP16_TO_FP32(x1->d) * y0->d,
+                    GGML_FP16_TO_FP32(x1->d) * y1->d,
+                };
+
+                visum = vsubq_s32(visum, vibias);
+                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+            }
+        }
+
+        // vfsum = ABCD -> ACBD
+        // AC -> s, BD -> (s+bs)
+        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+        vst1_f32(s,      vget_low_f32 (vfsum));
+        vst1_f32(s + bs, vget_high_f32(vfsum));
+
+        return;
+    }
+#endif
+
 #ifdef __ARM_FEATURE_SVE
     const int vector_length = ggml_cpu_get_sve_cnt()*8;
     float sum = 0;
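The "adjust bias" step above relies on a simple identity: q6_K stores quants biased by +32, and dot((q-32)·scale, y) = Σ scale·(q·y) − 32·Σ scale·Σy, where the per-16-element sums Σy are exactly the precomputed bsums of block_q8_K. A scalar sketch (illustrative shapes only — one superblock of 16 sub-blocks × 16 elements to mirror q6_K's 256 elements):

#include <cstdint>
#include <cstdio>

int main() {
    int8_t  q6[256];    // stored quants in [0, 63]; true value is q - 32
    int8_t  q8[256];    // activations
    int8_t  scales[16];
    int16_t bsums[16];  // per-16-element sums of q8, as in block_q8_K
    for (int i = 0; i < 256; ++i) { q6[i] = (int8_t)((i * 7) % 64); q8[i] = (int8_t)(i - 128); }
    for (int b = 0; b < 16; ++b) {
        scales[b] = (int8_t)(b - 8);
        int16_t s = 0;
        for (int i = 0; i < 16; ++i) s += q8[b * 16 + i];
        bsums[b] = s;
    }
    // Reference: apply the -32 bias per element.
    int32_t ref = 0;
    for (int b = 0; b < 16; ++b)
        for (int i = 0; i < 16; ++i)
            ref += scales[b] * (q6[b * 16 + i] - 32) * q8[b * 16 + i];
    // SIMD-friendly: dot the unbiased quants, then subtract 32 * dot(scales, bsums),
    // which is what the vibias correction in the kernel computes.
    int32_t dot = 0, bias = 0;
    for (int b = 0; b < 16; ++b) {
        for (int i = 0; i < 16; ++i) dot += scales[b] * q6[b * 16 + i] * q8[b * 16 + i];
        bias += scales[b] * bsums[b];
    }
    printf("%d == %d\n", ref, dot - 32 * bias); // identical
}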


@@ -286,7 +286,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .from_float = quantize_row_q6_K,
         .vec_dot = ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows = 2,
+#else
         .nrows = 1,
+#endif
     },
     [GGML_TYPE_IQ2_XXS] = {
         .from_float = NULL,
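What nrows = 2 means for the vec_dot contract, as implied by the SMMLA kernel above (which stores to s, s+1 and s+bs, s+bs+1): the driver hands the kernel two rows of x and two columns of y, and it fills a 2x2 output tile. A scalar mock of that contract — not the real quantized kernel:

#include <cstddef>

static void vec_dot_nrc2_mock(int n, float * s, size_t bs, const float * x, size_t bx,
                              const float * y, size_t by, int nrc) {
    // nrc == 2: two rows of x against two columns of y; s[col * bs + row]
    // receives dot(x_row, y_col), matching how the kernel above scatters
    // its four accumulators to s and s + bs.
    for (int col = 0; col < nrc; ++col) {
        const float * yc = (const float *)((const char *) y + col * by);
        for (int row = 0; row < nrc; ++row) {
            const float * xr = (const float *)((const char *) x + row * bx);
            float sum = 0.0f;
            for (int i = 0; i < n; ++i) sum += xr[i] * yc[i];
            s[col * bs + row] = sum;
        }
    }
}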


@@ -678,10 +678,14 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
+    const bool is_mla = DV == 512; // TODO better parameterization
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
+    GGML_ASSERT(V || is_mla);
+
     const ggml_tensor * mask = dst->src[3];
 
     ggml_tensor * KQV = dst;
@@ -689,6 +693,10 @@ void launch_fattn(
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -713,10 +721,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = (const char *) V->data;
-    size_t nb21 = V->nb[1];
-    size_t nb22 = V->nb[2];
-    size_t nb23 = V->nb[3];
+    const char * V_data = V ? (const char *) V->data : nullptr;
+    size_t nb21 = V ? V->nb[1] : nb11;
+    size_t nb22 = V ? V->nb[2] : nb12;
+    size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         GGML_ASSERT(ggml_is_contiguously_allocated(K));
@@ -733,7 +741,7 @@ void launch_fattn(
         nb13 = nb13*bs*sizeof(half)/ts;
     }
 
-    if (need_f16_V && V->type != GGML_TYPE_F16) {
+    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
         GGML_ASSERT(ggml_is_contiguously_allocated(V));
         V_f16.alloc(ggml_nelements(V));
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);


@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
static constexpr int nwarps_max = 4; static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true; static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2; static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 32;
static constexpr int nbatch_V2 = 32; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
static constexpr int nbatch_combine = 32; return 32;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 32;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 32;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 32;
}
}; };
template <> template <>
@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
static constexpr int nwarps_max = 4; static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true; static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2; static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 40;
static constexpr int nbatch_V2 = 40; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
static constexpr int nbatch_combine = 40; return 40;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 40;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 40;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 40;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 40;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 40;
}
}; };
template <> template <>
@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
static constexpr int nwarps_max = 4; static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true; static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2; static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 48;
static constexpr int nbatch_V2 = 48; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
static constexpr int nbatch_combine = 48; return 48;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 48;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 48;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 48;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 48;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 48;
}
}; };
template <> template <>
@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
static constexpr int nwarps_max = 4; static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true; static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2; static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 56;
static constexpr int nbatch_V2 = 56; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
static constexpr int nbatch_combine = 56; return 56;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 56;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 56;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 56;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 56;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 56;
}
}; };
template <> template <>
@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
static constexpr int nwarps_max = 4; static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true; static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2; static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 64;
static constexpr int nbatch_V2 = 64; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
static constexpr int nbatch_combine = 64; return 64;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 64;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 64;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 64;
}
}; };
template <> template <>
@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
static constexpr int nwarps_max = 4; static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true; static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2; static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 128;
static constexpr int nbatch_V2 = 128; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
static constexpr int nbatch_combine = 128; return 128;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 128;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 128;
}
static int get_nbatch_combine_host(const int cc, const int ncols) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
return ncols <= 16 ? 128 : 64;
}
return 128;
}
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
return ncols <= 16 ? 128 : 64;
#else
GGML_UNUSED(ncols);
return 128;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
}
}; };
template <> template <>
@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
static constexpr int nwarps_max = 8; static constexpr int nwarps_max = 8;
static constexpr bool Q_in_reg = false; static constexpr bool Q_in_reg = false;
static constexpr int nstages_target = 1; static constexpr int nstages_target = 1;
static constexpr int nbatch_K2 = 160;
static constexpr int nbatch_V2 = 128; static int get_nbatch_K2_host(const int cc, const int ncols) {
static constexpr int nbatch_combine = 128; if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
return ncols <= 16 ? 96 : 160;
}
return ncols <= 16 ? 288 : 160;
}
static constexpr __device__ int get_nbatch_K2_device(int ncols) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
return ncols <= 16 ? 96 : 160;
#else
return ncols <= 16 ? 288 : 160;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
}
static int get_nbatch_V2_host(const int cc, const int ncols) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
return ncols <= 16 ? 64 : 128;
}
return ncols <= 16 ? 256 : 128;
}
static constexpr __device__ int get_nbatch_V2_device(int ncols) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
return ncols <= 16 ? 64 : 128;
#else
return ncols <= 16 ? 256 : 128;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 128;
}
}; };
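A hedged sketch of the getter pattern these specializations adopt: each former compile-time constant nbatch_* becomes a host getter that switches on the device's compiled arch, plus a constexpr __device__ twin that switches on __CUDA_ARCH__, which lets Turing pick smaller K/V batches without disturbing other architectures. The names below are illustrative stand-ins, not the tree's symbols.

    // Stand-in for GGML_CUDA_CC_TURING; the real host getter compares
    // ggml_cuda_highest_compiled_arch(cc) against it.
    constexpr int k_turing = 750;
    struct cfg_576_512_sketch {
        static int nbatch_K2_host(int cc, int ncols) {
            if (cc == k_turing) {
                return ncols <= 16 ? 96 : 160;  // smaller tiles for Turing's shared memory
            }
            return ncols <= 16 ? 288 : 160;
        }
        // The device twin wraps the same ternaries in
        // #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING ... #else ... #endif
        // so the value remains a compile-time constant inside the kernel.
    };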
// ------------------------------------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------------------------------------
@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
auto load = [&] __device__ (const int n) { auto load = [&] __device__ (auto n) {
const int stride_k = WARP_SIZE >> n; const int stride_k = WARP_SIZE >> n;
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
} }
} }
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter> template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter( static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2, const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2, const half2 * const __restrict__ K_h2,
@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int ncols = ncols1 * ncols2;
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = c::nbatch_K2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4;
constexpr int stride_tile_V = c::nbatch_V2 + 4;
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
const int k_VKQ_0 = kb0 * c::nbatch_fa; const int k_VKQ_0 = kb0 * c::nbatch_fa;
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
if constexpr (nstages > 1) { if constexpr (nstages > 1) {
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); static_assert(!mla, "multi-stage loading not implemented for MLA");
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
constexpr bool use_cp_async = true; constexpr bool use_cp_async = true;
cp_async_wait_all(); cp_async_wait_all();
__syncthreads(); __syncthreads();
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async> flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
} else { } else {
constexpr bool use_cp_async = nstages == 1; constexpr bool use_cp_async = nstages == 1;
if (ncols2 > 1 || mask_h2) { if (ncols2 > 1 || mask_h2) {
@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
} }
#pragma unroll #pragma unroll
for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
const int k0_diff = k0_stop - k0_start; const int k0_diff = k0_stop - k0_start;
if (nstages <= 1) { if (nstages <= 1) {
@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
} }
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async> flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
} }
} }
#pragma unroll
for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
const int i0_diff = i0_stop - i0_start;
if (nstages <= 1) { // For MLA K and V have the same data.
// Therefore, iterate over V in reverse and re-use the data if possible.
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
#pragma unroll
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
const int i0_diff = i0_stop - i0_start;
if (nstages <= 1 && i0_start < reusable_cutoff) {
constexpr bool use_cp_async = nstages == 1; constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async> flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
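A worked check of the reusable_cutoff formula, under my reading of the hunk: for MLA, V is a suffix of K, so when V tiles are visited in reverse the highest-indexed ones can reuse the K tile already in shared memory, and the cutoff marks where fresh loads resume. With DKQ = 576, DV = 512 and nbatch_K2 = 160 the last K tile covers K columns [320, 576), i.e. V columns [256, 512).

    constexpr int reusable_cutoff(int DKQ, int DV, int nbatch_K2) {
        // same expression as in the kernel, folded into a checkable helper
        return (DKQ - 1) - (DKQ - 1) % (2 * nbatch_K2) - (DKQ - DV);
    }
    // V tiles with i0_start >= 256 reuse K's data; [0, 256) is loaded fresh.
    static_assert(reusable_cutoff(576, 512, 160) == 256, "MLA cutoff example");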
@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
} }
__syncthreads(); __syncthreads();
} }
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
// Calculate VKQ tile: // Calculate VKQ tile:
#pragma unroll #pragma unroll
@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int k0 = k00 + (threadIdx.y % np)*tile_A::J; const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
tile_A A; tile_A A;
load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
if (ntiles == 1) { if (ntiles == 1) {
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
} else { } else {
@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#endif // NEW_MMA_AVAILABLE #endif // NEW_MMA_AVAILABLE
} }
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup> template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2, const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2, const half2 * const __restrict__ K_h2,
@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = c::nbatch_K2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4;
constexpr int stride_tile_V = c::nbatch_V2 + 4;
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
extern __shared__ half2 tile_Q[]; extern __shared__ half2 tile_Q[];
@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// Preload mask and K data for first iteration when using cp_async with multiple stages: // Preload mask and K data for first iteration when using cp_async with multiple stages:
if constexpr (nstages > 1) { if constexpr (nstages > 1) {
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
constexpr bool use_cp_async = true; constexpr bool use_cp_async = true;
if (ncols2 > 1 || mask_h2) { if (ncols2 > 1 || mask_h2) {
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async> flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
} }
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async> flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
} }
// Iterate over ne11 == previous tokens: // Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false; constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter> flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
} }
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true; constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter> flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
} }
@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1. // So also write VKQ accumulators to shared memory in column-major format if np == 1.
constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
constexpr int tile_stride = nbatch_combine + 4; constexpr int tile_stride = nbatch_combine + 4;
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#endif // NEW_MMA_AVAILABLE #endif // NEW_MMA_AVAILABLE
} }
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap> template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
__launch_bounds__(nwarps*WARP_SIZE, 1) __launch_bounds__(nwarps*WARP_SIZE, 1)
static __global__ void flash_attn_ext_f16( static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q, const char * __restrict__ Q,
@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE; NO_DEVICE_CODE;
return; return;
} }
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
return;
}
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
typedef fattn_mma_f16_config<DKQ, DV> c; typedef fattn_mma_f16_config<DKQ, DV> c;
@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q1 = nb01 / sizeof(float2);
const int stride_Q2 = nb02 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2);
const int stride_K = nb11 / sizeof(half2); const int stride_K = nb11 / sizeof(half2);
const int stride_V = nb21 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2);
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter; const int kb0_start_kernel = kb0_start * kb_niter;
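Illustrative arithmetic only: with the MLA head sizes DKQ = 576 and DV = 512, the pointer bump above is 576/2 - 512/2 = 32 half2 elements (64 half values), so V_h2 addresses the trailing 512 values of each K row rather than a separate buffer.

    static_assert(576 / 2 - 512 / 2 == 32, "MLA: V offset inside K, in half2 units");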
@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) { if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup> flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else { } else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup> flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} }
@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter; const int kb0_start_kernel = kb0_start * kb_niter;
@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false; constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup> flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else #else
@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
typedef fattn_mma_f16_config<DKQ, DV> c; typedef fattn_mma_f16_config<DKQ, DV> c;
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
const int nstages = cp_async_available(cc) ? c::nstages_target : 0; const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
constexpr int ncols = ncols1 * ncols2; constexpr int ncols = ncols1 * ncols2;
@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
constexpr bool mla = DKQ == 576;
const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
const int nbatch_V2 = c::get_nbatch_V2_host (cc, ncols);
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
static_assert(DKQ % tile_B::J == 0, "bad DKQ"); static_assert(DKQ % tile_B::J == 0, "bad DKQ");
static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(DV % tile_A::J == 0, "bad DV");
static_assert(ncols % cols_per_warp == 0, "bad ncols"); static_assert(ncols % cols_per_warp == 0, "bad ncols");
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
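The two sizes above differ because the single-stage path can overlay K and V in one buffer, while the two-stage (cp_async) pipeline needs both tiles resident at once. A self-checking sketch for one assumed shape (nbatch_fa = 64, nbatch_K2 = nbatch_V2 = 64, sizeof(half2) = 4):

    #include <cstddef>
    constexpr size_t kv_smem_bytes(int nbatch_fa, int K2, int V2, bool two_stage) {
        return two_stage
            ? size_t(nbatch_fa) * size_t(K2 + 4 + V2 + 4) * 4                         // both tiles live
            : size_t(nbatch_fa) * size_t(K2 + 4 > V2 + 4 ? K2 + 4 : V2 + 4) * 4;      // shared buffer
    }
    static_assert(kv_smem_bytes(64, 64, 64, false) == 17408, "single stage");
    static_assert(kv_smem_bytes(64, 64, 64, true)  == 34816, "two stages");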
@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
fattn_kernel_t fattn_kernel; fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) { if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false; constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>; fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
} else { } else {
constexpr bool use_logit_softcap = true; constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>; fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@ -10,6 +10,7 @@
template <int DKQ, int DV, int ncols2> template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
if constexpr (ncols2 <= 8) { if constexpr (ncols2 <= 8) {
@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
return; return;
} }
if (Q->ne[1] <= 32/ncols2) { if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst); ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
return; return;
} }
@ -3227,7 +3227,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#endif // FLASH_ATTN_AVAILABLE #endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) { if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) { if (!new_mma_available(cc)) {
return false; return false;
} }
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
CUDA_CHECK(cudaGetLastError());
} }
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s13 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[2] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type, quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
CUDA_CHECK(cudaGetLastError());
} }
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
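The CUDA_CHECK(cudaGetLastError()) lines added above follow the usual pattern for asynchronous launches: a failed kernel launch only surfaces through cudaGetLastError(), so checking immediately after the call attributes the failure to the right site instead of some later synchronization point. A generic sketch of the pattern, not the tree's macro:

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>
    #define CUDA_CHECK_SKETCH(expr)                                        \
        do {                                                               \
            cudaError_t err_ = (expr);                                     \
            if (err_ != cudaSuccess) {                                     \
                fprintf(stderr, "CUDA error %s at %s:%d\n",                \
                        cudaGetErrorString(err_), __FILE__, __LINE__);     \
                exit(1);                                                   \
            }                                                              \
        } while (0)
    // usage right after an asynchronous launch:
    //   my_kernel<<<grid, block, 0, stream>>>(...);
    //   CUDA_CHECK_SKETCH(cudaGetLastError());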
@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
if (i0 >= ne0) { if (i0 >= ne0) {
return; return;
} }
const int64_t i1 = blockIdx.y; const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2; const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2; const int64_t i3 = blockIdx.z / ne2;
@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
block_q8_1_mmq * y = (block_q8_1_mmq *) vy; block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
const int64_t iqs = i0 % (4*QK8_1); // quant index in block const int64_t iqs = i0 % (4*QK8_1); // quant index in block
// Load 4 floats per thread and calculate max. abs. value between them: // Load 4 floats per thread and calculate max. abs. value between them:
@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda(
GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne0 % (4*QK8_1) == 0); GGML_ASSERT(ne0 % (4*QK8_1) == 0);
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
switch (mmq_get_q8_1_ds_layout(type_src0)) { switch (mmq_get_q8_1_ds_layout(type_src0)) {
case MMQ_Q8_1_DS_LAYOUT_D4: case MMQ_Q8_1_DS_LAYOUT_D4:
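The grid remap above exists because CUDA caps gridDim.y and gridDim.z at 65535 while gridDim.x may go up to 2^31 - 1 blocks; ne1 (the token dimension) is the one that grows, so it moves onto x. A hedged sketch of the launch-geometry computation, assuming the code's CUDA_QUANTIZE_BLOCK_SIZE_MMQ is 128:

    #include <cstdint>
    struct grid_dims { int64_t x, y, z; };
    static grid_dims quantize_grid(int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
        const int64_t block_size  = 128;  // assumed CUDA_QUANTIZE_BLOCK_SIZE_MMQ
        const int64_t block_num_y = (ne0 + 4*block_size - 1) / (4*block_size);
        return {ne1, block_num_y, ne2*ne3};  // large ne1 rides the roomy x dimension
    }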
@ -56,6 +56,28 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
}; };
template <> struct block_q_t<GGML_TYPE_Q4_K> {
struct traits {
static constexpr uint32_t qk = QK_K;
static constexpr uint32_t qi = QI4_K;
static constexpr uint32_t qr = QR4_K;
static constexpr uint32_t vdr_mmvq = 2;
};
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
auto nblocks = (nrows * (ncols / traits::qk));
return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
}
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
};
} // namespace ggml_sycl_reordered } // namespace ggml_sycl_reordered
#endif // GGML_SYCL_QUANTS_HPP #endif // GGML_SYCL_QUANTS_HPP
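A self-checking model of the reordered Q4_K layout implied by get_d_offset above: all quant bytes first (QK_K/2 per block), then all scale blocks (K_SCALE_SIZE each), then the per-block dm half2 pairs. QK_K = 256 and K_SCALE_SIZE = 12 mirror ggml's constants; sizeof(ggml_half2) = 4 bytes is assumed.

    constexpr int QK_K_sketch = 256, K_SCALE_SIZE_sketch = 12, HALF2_BYTES = 4;
    constexpr int d_offset_sketch(int nrows, int ncols, int block_index) {
        return nrows * (ncols / QK_K_sketch) * QK_K_sketch / 2
             + nrows * (ncols / QK_K_sketch) * K_SCALE_SIZE_sketch
             + block_index * HALF2_BYTES;
    }
    // one 256-wide row = one block: 128 B of qs, then 12 B of scales, dm at 140
    static_assert(d_offset_sketch(1, 256, 0) == 140, "Q4_K reordered layout");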
@ -304,6 +304,9 @@ struct vk_device_struct {
bool coopmat_acc_f32_support {}; bool coopmat_acc_f32_support {};
bool coopmat_acc_f16_support {}; bool coopmat_acc_f16_support {};
bool coopmat_bf16_support {}; bool coopmat_bf16_support {};
bool coopmat_support_16x16x16_f16acc {};
bool coopmat_support_16x16x16_f32acc {};
bool coopmat1_fa_support {};
uint32_t coopmat_m; uint32_t coopmat_m;
uint32_t coopmat_n; uint32_t coopmat_n;
uint32_t coopmat_k; uint32_t coopmat_k;
@ -426,6 +429,13 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
@ -1604,19 +1614,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
); );
} }
enum FaCodePath {
FA_SCALAR,
FA_COOPMAT1,
FA_COOPMAT2,
};
// number of rows/cols for flash attention shader // number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
static uint32_t get_fa_num_small_rows(bool scalar) { // The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows; // 128 threads split into four subgroups, each subgroup does 1/4
// of the Bc dimension.
static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
static constexpr uint32_t scalar_flash_attention_Bc = 64;
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
static uint32_t get_fa_num_small_rows(FaCodePath path) {
if (path == FA_COOPMAT2) {
return flash_attention_num_small_rows;
} else {
return scalar_flash_attention_num_small_rows;
}
} }
static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
GGML_UNUSED(clamp); GGML_UNUSED(clamp);
if (scalar) { if (path == FA_SCALAR) {
if (small_rows) { if (small_rows) {
return {scalar_flash_attention_num_small_rows, 64}; return {scalar_flash_attention_num_small_rows, 64};
} else { } else {
@ -1624,9 +1651,17 @@ static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t cl
} }
} }
if (path == FA_COOPMAT1) {
if (small_rows) {
return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
} else {
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
}
}
// small rows, large cols // small rows, large cols
if (small_rows) { if (small_rows) {
return {get_fa_num_small_rows(scalar), 32}; return {get_fa_num_small_rows(FA_COOPMAT2), 32};
} }
// small cols to reduce register count // small cols to reduce register count
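A rough table of the Br/Bc (rows/cols) choices the function now makes per code path; the scalar large-rows value and the coopmat2 large-rows branch are elided in this diff, so those parts are my assumptions (8 rows for scalar, register-pressure tuning for cm2).

    enum fa_path_sketch { PATH_SCALAR, PATH_COOPMAT1, PATH_COOPMAT2 };
    struct rows_cols { unsigned rows, cols; };
    static rows_cols fa_rows_cols_sketch(fa_path_sketch path, bool small_rows) {
        if (path == PATH_SCALAR) {
            return small_rows ? rows_cols{1, 64} : rows_cols{8, 64};   // 8 assumed
        }
        if (path == PATH_COOPMAT1) {
            return small_rows ? rows_cols{1, 64} : rows_cols{16, 64};  // Bc fixed at 64
        }
        return rows_cols{32, 32};  // coopmat2 small-rows case from the diff
    }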
@ -1923,17 +1958,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
}; };
auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> { auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1}; return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
}; };
auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> { auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best. // For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80. // can't use 256 for D==80.
// For scalar, use 128 (arbitrary) // For scalar, use 128 (arbitrary)
uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128); uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows); ? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@ -1945,36 +1982,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
}; };
#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \ #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \ #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256) CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
CREATE_FA(GGML_TYPE_F16, f16, true, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, ) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) { if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F16, f16, false, _cm2) CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2) CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2) CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2) CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2) CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
} }
#endif #endif
#undef CREATE_FA2 #undef CREATE_FA2
@ -2057,17 +2101,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
// Create 6 variants, {s,m,l}x{unaligned,aligned} // Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \ if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \ if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \ if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _l[TYPE]) \ if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \ if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \ if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
// Create 2 variants, {f16,f32} accumulator // Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@ -3033,6 +3077,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
#if defined(VK_KHR_cooperative_matrix) #if defined(VK_KHR_cooperative_matrix)
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
// coopmat1 fa shader currently assumes 32 invocations per subgroup
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
device->subgroup_size_control && device->subgroup_min_size <= 32 &&
device->subgroup_max_size >= 32;
#endif #endif
if (coopmat2_support) { if (coopmat2_support) {
@ -3167,6 +3216,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
// Only enable if shape is identical // Only enable if shape is identical
device->coopmat_acc_f32_support = true; device->coopmat_acc_f32_support = true;
} }
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
device->coopmat_support_16x16x16_f32acc = true;
}
            } else if ((vk::ComponentTypeKHR)prop.CType      == vk::ComponentTypeKHR::eFloat16 &&
                       (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
                // coopmat sizes not set yet
@ -3179,6 +3231,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
                // Only enable if shape is identical
                device->coopmat_acc_f16_support = true;
            }
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
device->coopmat_support_16x16x16_f16acc = true;
}
            }
        } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
                   (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
@ -5712,6 +5767,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
        }
    }
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
// Needs to be kept up to date on shader changes
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = scalar_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype;
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype;
const uint32_t kshstride = D / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4;
const uint32_t slope = Br * sizeof(float);
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
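
To make the sizing rule concrete, here is a hedged standalone sketch that evaluates the same formula. The tuning constants (wg_size = 128, Br = 8, Bc = 64) and the 48 KB limit are illustrative assumptions, not values taken from this diff; the real ones come from the device properties and the scalar-path constants elsewhere in this file.

#include <cstdint>
#include <cstdio>

// Hypothetical standalone version of the shared-memory sizing rule above.
static bool fa_coopmat1_shmem_fits(uint32_t D, bool f32acc, uint32_t max_shmem) {
    const uint32_t wg_size = 128;               // assumed workgroup size
    const uint32_t Br      = 8;                 // assumed rows per block
    const uint32_t Bc      = 64;                // assumed cols per block
    const uint32_t acctype = f32acc ? 4 : 2;    // bytes per accumulator element
    const uint32_t f16vec4 = 8;                 // bytes per f16vec4

    const uint32_t tmpsh      = wg_size * sizeof(float);
    const uint32_t tmpshv4    = wg_size * 4 * acctype;
    const uint32_t Qf         = Br * (D / 4 + 2) * f16vec4;
    const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
    const uint32_t sfsh       = Bc * sfshstride * acctype;
    const uint32_t kshstride  = D / 4 + 2;
    const uint32_t ksh        = Bc * kshstride * f16vec4;
    const uint32_t slope      = Br * sizeof(float);

    const uint32_t total = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
    printf("D=%u f32acc=%d -> %u bytes (limit %u)\n", D, (int) f32acc, total, max_shmem);
    return total <= max_shmem;
}

int main() {
    // With these assumed constants, D=128/f32acc needs
    // 512 + 2048 + 2176 + 4096 + 17408 + 32 = 26272 bytes, well under 48 KB.
    fa_coopmat1_shmem_fits(128, true, 48 * 1024);
    fa_coopmat1_shmem_fits(256, true, 48 * 1024);
    return 0;
}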
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
    VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
    std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
@ -5762,7 +5847,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    assert(q->type == GGML_TYPE_F32);
    assert(k->type == v->type);

    FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
if (path == FA_COOPMAT1) {
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
path = FA_SCALAR;
}
}
    uint32_t gqa_ratio = 1;
    uint32_t qk_ratio = neq2 / nek2;
@ -5770,9 +5867,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    uint32_t workgroups_y = (uint32_t)neq2;
    uint32_t workgroups_z = (uint32_t)neq3;

    // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
    // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
    uint32_t max_gqa;
switch (path) {
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = scalar_flash_attention_num_large_rows;
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
break;
default:
GGML_ASSERT(0);
}
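
A worked example of the grouping this sets up, with assumed head counts. The body of the condition below is elided in this excerpt, so the division of workgroups_y is an assumption about how gqa_ratio is applied:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t neq2 = 32, nek2 = 8;      // assumed: 32 Q heads, 8 KV heads
    const uint32_t N = 1;                    // single-token decode
    const uint32_t max_gqa = 8;              // assumed scalar-path limit
    const uint32_t qk_ratio = neq2 / nek2;   // = 4 query heads per KV head

    uint32_t gqa_ratio = 1, workgroups_y = neq2;
    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2) {
        gqa_ratio     = qk_ratio;    // heads sharing a KV head are grouped...
        workgroups_y /= qk_ratio;    // ...into one workgroup each: 32 -> 8
    }
    printf("gqa_ratio=%u workgroups_y=%u\n", gqa_ratio, workgroups_y);
    return 0;
}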
    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
@ -5785,11 +5894,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    }

    vk_pipeline *pipelines;
    bool small_rows = N <= get_fa_num_small_rows(path);

    if (small_rows && path == FA_COOPMAT1) {
        path = FA_SCALAR;
    }

    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
switch (path) {
case FA_SCALAR:
        switch (D) {
        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
@ -5801,7 +5915,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
            GGML_ASSERT(!"unsupported D value");
            return;
        }
        break;
case FA_COOPMAT1:
switch (D) {
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
break;
case FA_COOPMAT2:
        switch (D) {
        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
@ -5813,6 +5941,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
            GGML_ASSERT(!"unsupported D value");
            return;
        }
break;
default:
GGML_ASSERT(0);
    }

    assert(pipelines);


@ -12,6 +12,7 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
@ -19,7 +20,7 @@ layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 5) const uint32_t D_split = 16;
const uint32_t D_per_thread = D / D_split;

const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;

layout (push_constant) uniform parameter {
@ -134,8 +135,8 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}

shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];

shared float masksh[Bc][Br];
shared vec4 Qf[Br][D / 4];
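
The shader now sizes its shared arrays with the explicit WorkGroupSize specialization constant (constant_id 0) instead of gl_WorkGroupSize.x. The host supplies that value at pipeline-creation time; here is a minimal hedged sketch of the Vulkan side (the helper name and layout are illustrative, not the actual ggml-vulkan plumbing):

#include <vulkan/vulkan.h>
#include <cstdint>

// Hypothetical helper: bind a workgroup size to specialization constant 0.
// wg_size_storage must outlive the pipeline-creation call that consumes this.
VkSpecializationInfo make_wg_size_spec(const uint32_t * wg_size_storage) {
    static const VkSpecializationMapEntry entry = {
        0,                 // constantID 0 == WorkGroupSize in the shader
        0,                 // offset into pData
        sizeof(uint32_t),  // size of the constant
    };
    VkSpecializationInfo info = {};
    info.mapEntryCount = 1;
    info.pMapEntries   = &entry;
    info.dataSize      = sizeof(uint32_t);
    info.pData         = wg_size_storage;
    return info;
}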


@ -0,0 +1,506 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#include "types.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 5) const uint32_t D_split = 16;
const uint32_t D_per_thread = D / D_split;
const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * D + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
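
// Worked example (illustrative, assuming the host sets max_bias = 8 so that
// m0 = 2^(-8/n_head_log2) and m1 = 2^(-4/n_head_log2), the usual ALiBi bases):
//   n_head_log2 = 8, h = 3  -> slope = m0^(h+1)        = 2^-4   = 0.0625
//   n_head_log2 = 8, h = 10 -> slope = m1^(2*(10-8)+1) = 2^-2.5 ~= 0.177
// Rows of one workgroup differ only through (r % p.gqa_ratio), so each grouped
// query head receives its own per-head slope.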
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
const uint32_t qstride = D / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
// Avoid padding for D==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t N = p.N;
const uint32_t KV = p.KV;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
#define tile_row(r) (row_tid * rows_per_thread + (r))
uint32_t i = gl_WorkGroupID.x;
uint32_t split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
const uint32_t Tr = CEIL_DIV(N, Br);
const uint32_t start_j = split_k_index * p.split_kv / Bc;
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z;
// broadcast factors
const uint32_t rk2 = p.neq2/p.nek2;
const uint32_t rk3 = p.neq3/p.nek3;
const uint32_t rv2 = p.neq2/p.nev2;
const uint32_t rv3 = p.neq3/p.nev3;
// k indices
const uint32_t ik3 = iq3 / rk3;
const uint32_t ik2 = iq2 / rk2;
// v indices
const uint32_t iv3 = iq3 / rv3;
const uint32_t iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
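// Readable equivalent of the selection above, kept in the obfuscated form so
// the compiler cannot fold the zero through the select and lose the alignment
// information:
//   uint32_t m_stride = (p.gqa_ratio > 1) ? 0 : KV;
// i.e. with grouped query attention every row reads the same mask row.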
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4);
uint32_t r = (idx + tid) / (D / 4);
if (r < Br && d < D / 4 &&
i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPEV4(0.0);
}
}
float Lf[rows_per_thread], Mf[rows_per_thread];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
// ALiBi
if (p.max_bias > 0.0f) {
if (tid < Br) {
uint r = tid;
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
barrier();
} else {
if (tid < Br) {
uint r = tid;
slope[r] = 1.0;
}
barrier();
}
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4);
uint32_t c = (idx + tid) / (D / 4);
if (c < Bc && d < D / 4) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
ksh[c * kshstride + d] = K_Tf;
}
}
barrier();
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < D / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
uint coord = gl_SubgroupID * MatBc * sfshstride;
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / Br;
uint32_t r = (idx + tid) % Br;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
}
}
barrier();
}
if (p.mask != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
}
}
barrier();
}
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
}
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf[r] = exp(Moldf - Mf[r]);
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
float Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
}
}
}
barrier();
}
// reduce across threads
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE M = Mf[r];
tmpsh[tid] = M;
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
M = max(M, tmpsh[tid ^ s]);
barrier();
tmpsh[tid] = M;
barrier();
}
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
eMf[r] = exp(Moldf[r] - Mf[r]);
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE L = Lf[r];
tmpsh[tid] = L;
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
L += tmpsh[tid ^ s];
barrier();
tmpsh[tid] = L;
barrier();
}
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
Of[r][d] += tmpshv4[tid ^ s];
barrier();
tmpshv4[tid] = Of[r][d];
barrier();
}
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
}
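
// All three reductions above share one XOR-butterfly pattern: at each step a
// thread combines with its partner (tid ^ s) while s halves from
// threads_per_rowgroup/2 down to D_split, so lanes that work on the same row
// converge to a single M / L / O value while the D_split lanes owning
// different slices of the head dimension stay independent.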
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * split_k_index;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
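
// For reference, the split_k resolve pass (a separate shader) combines the
// per-split partials stored above as:
//   M = max_i M_i,   L = sum_i e^(M_i - M) * L_i,
//   O = (sum_i e^(M_i - M) * O_i) / L
// which is why O plus the per-row (L, M) pair is all that gets written here.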
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= float16_t(Lfrcp[r]);
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
}
}
}
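
For readers tracing the M/L/O bookkeeping in this shader, the same one-pass ("online") softmax accumulation can be stated in a few scalar lines. This is a hedged, self-contained C++ reference of the math only: no tiling, no split_k, and plain -inf instead of the shader's -FLT_MAX/2 NaN guard.

#include <cmath>
#include <cstdio>
#include <vector>

// out = sum_j softmax(s)_j * v_j, computed in one pass over the scores,
// mirroring the shader's running max (Mf), denominator (Lf) and output (Of).
int main() {
    const std::vector<float> s = {0.5f, 2.0f, -1.0f, 3.0f};  // assumed scores
    const std::vector<float> v = {1.0f, 2.0f,  3.0f, 4.0f};  // assumed values

    float M = -INFINITY;  // running max
    float L = 0.0f;       // running softmax denominator
    float O = 0.0f;       // running weighted sum

    for (size_t j = 0; j < s.size(); ++j) {
        const float Mold = M;
        M = std::fmax(M, s[j]);
        const float eM = std::exp(Mold - M);  // rescale factor for old state
        const float P  = std::exp(s[j] - M);  // un-normalized weight of v[j]
        L = eM * L + P;
        O = eM * O + P * v[j];
    }
    printf("attention output = %f\n", O / L);  // final division by L
    return 0;
}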


@ -227,7 +227,7 @@ static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;

void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
    std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
    std::string out_fname = join_paths(output_dir, name + ".spv");
    std::string in_path = join_paths(input_dir, in_fname);
@ -438,6 +438,7 @@ void process_shaders() {
    // flash attention
    for (const auto& f16acc : {false, true}) {
        std::string acctype = f16acc ? "float16_t" : "float";
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
        for (const auto& tname : type_names) {
            if (tname == "f32") {
@ -454,6 +455,16 @@ void process_shaders() {
                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
            }
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
}
#endif

            if (tname == "f16") {
                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",


@ -309,10 +309,10 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
            return false;
        }
    } catch (std::length_error &) {
        GGML_LOG_ERROR("%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
        return false;
    } catch (std::bad_alloc &) {
        GGML_LOG_ERROR("%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
        return false;
    }
    kv.emplace_back(key, value);
@ -338,14 +338,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    ok = ok && gr.read(magic, 4);

    if (!ok) {
        GGML_LOG_ERROR("%s: failed to read magic\n", __func__);
        gguf_free(ctx);
        return nullptr;
    }

    for (uint32_t i = 0; i < magic.size(); i++) {
        if (magic[i] != GGUF_MAGIC[i]) {
            GGML_LOG_ERROR("%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
            gguf_free(ctx);
            return nullptr;
        }
    }
@ -393,7 +393,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    if (ok && gr.read(n_tensors)) {
        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
        if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
            GGML_LOG_ERROR("%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
                __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
            ok = false;
        }
@ -404,7 +404,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    if (ok && gr.read(n_kv)) {
        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
        if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
            GGML_LOG_ERROR("%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
                __func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
            ok = false;
        }
@ -414,7 +414,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    }

    if (!ok) {
        GGML_LOG_ERROR("%s: failed to read header\n", __func__);
        gguf_free(ctx);
        return nullptr;
    }
@ -430,15 +430,15 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
        try {
            ok = ok && gr.read(key);
        } catch (std::length_error &) {
            GGML_LOG_ERROR("%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
            ok = false;
        } catch (std::bad_alloc &) {
            GGML_LOG_ERROR("%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
            ok = false;
        }
        for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
            if (key == ctx->kv[j].key) {
                GGML_LOG_ERROR("%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
                ok = false;
            }
        }
@ -479,14 +479,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
            case GGUF_TYPE_ARRAY:
            default:
                {
                    GGML_LOG_ERROR("%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
                    ok = false;
                } break;
        }
    }

    if (!ok) {
        GGML_LOG_ERROR("%s: failed to read key-value pairs\n", __func__);
        gguf_free(ctx);
        return nullptr;
    }
@ -496,7 +496,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);

    if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
        GGML_LOG_ERROR("%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
        gguf_free(ctx);
        return nullptr;
    }
@ -512,14 +512,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
        try {
            ok = ok && gr.read(name);
        } catch (std::length_error &) {
            GGML_LOG_ERROR("%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
            ok = false;
        } catch (std::bad_alloc &) {
            GGML_LOG_ERROR("%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
            ok = false;
        }
        if (name.length() >= GGML_MAX_NAME) {
            GGML_LOG_ERROR("%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
            ok = false;
            break;
        }
@ -528,7 +528,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
        // make sure there are no duplicate tensor names
        for (int64_t j = 0; ok && j < i; ++j) {
            if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
                GGML_LOG_ERROR("%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
                ok = false;
                break;
            }
        }
@ -543,7 +543,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
        uint32_t n_dims = -1;
        ok = ok && gr.read(n_dims);
        if (n_dims > GGML_MAX_DIMS) {
            GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
                __func__, info.t.name, n_dims, GGML_MAX_DIMS);
            ok = false;
            break;
@ -563,7 +563,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
            // check that all ne are non-negative
            if (info.t.ne[j] < 0) {
                GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
                    __func__, info.t.name, j, info.t.ne[j]);
                ok = false;
                break;
@ -575,7 +575,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
            (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
            (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {

            GGML_LOG_ERROR("%s: total number of elements in tensor '%s' with shape "
                "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
                __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
            ok = false;
@ -592,7 +592,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
        // check that tensor type is within defined range
        if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
            GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n",
                __func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
            ok = false;
            break;
@ -602,7 +602,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
        // check that row size is divisible by block size
        if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
            GGML_LOG_ERROR("%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
                "not a multiple of block size (%" PRId64 ")\n",
                __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
            ok = false;
@ -627,7 +627,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    }

    if (!ok) {
        GGML_LOG_ERROR("%s: failed to read tensor info\n", __func__);
        gguf_free(ctx);
        return nullptr;
    }
@ -635,7 +635,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    // we require the data section to be aligned, so take into account any padding
    if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
        GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
        gguf_free(ctx);
        return nullptr;
    }
@ -649,9 +649,9 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    for (size_t i = 0; i < ctx->info.size(); ++i) {
        const gguf_tensor_info & ti = ctx->info[i];
        if (ti.offset != ctx->size) {
            GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
                __func__, ti.t.name, ti.offset, ctx->size);
            GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
            gguf_free(ctx);
            return nullptr;
        }
@ -679,7 +679,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    *params.ctx = ggml_init(pdata);
    if (*params.ctx == nullptr) {
        GGML_LOG_ERROR("%s: failed to initialize ggml context for storing tensors\n", __func__);
        gguf_free(ctx);
        return nullptr;
    }
@ -701,7 +701,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
        ok = ok && gr.read(data->data, ctx->size);
        if (!ok) {
            GGML_LOG_ERROR("%s: failed to read tensor data binary blob\n", __func__);
            ggml_free(ctx_data);
            *params.ctx = nullptr;
            gguf_free(ctx);
@ -734,7 +734,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
    }

    if (!ok) {
        GGML_LOG_ERROR("%s: failed to create tensors\n", __func__);
        ggml_free(ctx_data);
        *params.ctx = nullptr;
        gguf_free(ctx);
@ -751,7 +751,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
    FILE * file = ggml_fopen(fname, "rb");
    if (!file) {
        GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
        return nullptr;
    }
@ -1350,7 +1350,7 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
    FILE * file = ggml_fopen(fname, "wb");
    if (!file) {
        GGML_LOG_ERROR("%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
        return false;
    }
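
These call sites now go through ggml's logging layer instead of writing to stderr directly, so embedders can capture or silence them. A minimal sketch of installing a callback follows; ggml_log_set is the public ggml entry point, while the handler body is purely illustrative:

#include "ggml.h"
#include <cstdio>

// Illustrative handler: forward only error records to stderr with a prefix.
static void my_log_handler(enum ggml_log_level level, const char * text, void * user_data) {
    (void) user_data;
    if (level == GGML_LOG_LEVEL_ERROR) {
        fprintf(stderr, "[ggml error] %s", text);
    }
}

int main() {
    ggml_log_set(my_log_handler, /*user_data=*/nullptr);
    // gguf_init_from_file() failures will now reach my_log_handler.
    return 0;
}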


@ -823,6 +823,7 @@ class GGUFEditorWindow(QMainWindow):
        self.modified = False
        self.metadata_changes = {}  # Store changes to apply when saving
        self.metadata_to_remove = set()  # Store keys to remove when saving
        self.on_metadata_changed_is_connected = False

        self.setup_ui()
@ -941,9 +942,11 @@ class GGUFEditorWindow(QMainWindow):
            return

        # Disconnect to prevent triggering during loading
        if self.on_metadata_changed_is_connected:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')
self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
self.on_metadata_changed_is_connected = False
        for i, (key, field) in enumerate(self.reader.fields.items()):
            self.metadata_table.insertRow(i)
@ -1021,6 +1024,7 @@ class GGUFEditorWindow(QMainWindow):
        # Reconnect after loading
        self.metadata_table.itemChanged.connect(self.on_metadata_changed)
self.on_metadata_changed_is_connected = True
    def extract_array_values(self, field: ReaderField) -> list:
        """Extract all values from an array field."""


@ -68,7 +68,7 @@ class TensorNameMap:
"output_layer", # chatglm "output_layer", # chatglm
"head", # rwkv "head", # rwkv
"head.out", # wavtokenizer "head.out", # wavtokenizer
"language_model.lm_head", # llama4 "lm_head", # llama4
), ),
# Output norm # Output norm
@ -91,7 +91,7 @@ class TensorNameMap:
"rwkv.ln_out", # rwkv6 "rwkv.ln_out", # rwkv6
"model.ln_out", # rwkv7 "model.ln_out", # rwkv7
"backbone.final_layer_norm", # wavtokenizer "backbone.final_layer_norm", # wavtokenizer
"language_model.model.norm", # llama4 "model.norm", # llama4
), ),
# Rope frequencies # Rope frequencies
@ -133,7 +133,7 @@ class TensorNameMap:
"transformer.layers.{bid}.attn_norm", # openelm "transformer.layers.{bid}.attn_norm", # openelm
"rwkv.blocks.{bid}.ln1", # rwkv6 "rwkv.blocks.{bid}.ln1", # rwkv6
"model.layers.{bid}.ln1", # rwkv7 "model.layers.{bid}.ln1", # rwkv7
"language_model.model.layers.{bid}.input_layernorm", # llama4 "model.layers.{bid}.input_layernorm", # llama4
), ),
# Attention norm 2 # Attention norm 2
@ -173,7 +173,7 @@ class TensorNameMap:
"model.layers.{bid}.attention.wq", # internlm2 "model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
"transformer.h.{bid}.attn.attention.q_proj", # exaone "transformer.h.{bid}.attn.attention.q_proj", # exaone
"language_model.model.layers.{bid}.self_attn.q_proj", # llama4 "model.layers.{bid}.self_attn.q_proj", # llama4
), ),
# Attention key # Attention key
@ -188,7 +188,7 @@ class TensorNameMap:
"model.layers.{bid}.attention.wk", # internlm2 "model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
"transformer.h.{bid}.attn.attention.k_proj", # exaone "transformer.h.{bid}.attn.attention.k_proj", # exaone
"language_model.model.layers.{bid}.self_attn.k_proj", # llama4 "model.layers.{bid}.self_attn.k_proj", # llama4
), ),
# Attention value # Attention value
@ -202,7 +202,7 @@ class TensorNameMap:
"model.layers.{bid}.attention.wv", # internlm2 "model.layers.{bid}.attention.wv", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
"transformer.h.{bid}.attn.attention.v_proj", # exaone "transformer.h.{bid}.attn.attention.v_proj", # exaone
"language_model.model.layers.{bid}.self_attn.v_proj", # llama4 "model.layers.{bid}.self_attn.v_proj", # llama4
), ),
# Attention output # Attention output
@ -229,7 +229,7 @@ class TensorNameMap:
"encoder.layers.{bid}.self_attention.dense", # chatglm "encoder.layers.{bid}.self_attention.dense", # chatglm
"transformer.layers.{bid}.attn.out_proj", # openelm "transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone "transformer.h.{bid}.attn.attention.out_proj", # exaone
"language_model.model.layers.{bid}.self_attn.o_proj", # llama4 "model.layers.{bid}.self_attn.o_proj", # llama4
), ),
# Attention output norm # Attention output norm
@ -268,7 +268,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm "encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm "transformer.layers.{bid}.ffn_norm", # openelm
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4 "model.layers.{bid}.post_attention_layernorm", # llama4
), ),
# Post feed-forward norm # Post feed-forward norm
@ -289,7 +289,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.router", # Grok "transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx "transformer.blocks.{bid}.ffn.router.layer", # dbrx
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"language_model.model.layers.{bid}.feed_forward.router", # llama4 "model.layers.{bid}.feed_forward.router", # llama4
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
), ),
@ -329,7 +329,7 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w3", # arctic "model.layers.{bid}.residual_mlp.w3", # arctic
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone "transformer.h.{bid}.mlp.c_fc_1", # exaone
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4 "model.layers.{bid}.feed_forward.up_proj", # llama4
), ),
MODEL_TENSOR.FFN_UP_EXP: ( MODEL_TENSOR.FFN_UP_EXP: (
@ -338,14 +338,14 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
"language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
), ),
MODEL_TENSOR.FFN_UP_SHEXP: ( MODEL_TENSOR.FFN_UP_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
), ),
# AWQ-activation gate # AWQ-activation gate
@ -366,22 +366,22 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.linear_1", # refact "transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic "model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone "transformer.h.{bid}.mlp.c_fc_0", # exaone
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4 "model.layers.{bid}.feed_forward.gate_proj", # llama4
), ),
MODEL_TENSOR.FFN_GATE_EXP: ( MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged) "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
"language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
), ),
MODEL_TENSOR.FFN_GATE_SHEXP: ( MODEL_TENSOR.FFN_GATE_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
"language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
), ),
# Feed-forward down # Feed-forward down
@ -410,7 +410,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone "model.layers.h.{bid}.mlp.c_proj", # exaone
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4 "model.layers.{bid}.feed_forward.down_proj", # llama4
), ),
MODEL_TENSOR.FFN_DOWN_EXP: ( MODEL_TENSOR.FFN_DOWN_EXP: (
@ -420,15 +420,15 @@ class TensorNameMap:
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
"language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4 "model.layers.{bid}.feed_forward.experts.down_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
), ),
MODEL_TENSOR.FFN_DOWN_SHEXP: ( MODEL_TENSOR.FFN_DOWN_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
"language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe "model.layers.{bid}.shared_mlp.output_linear", # granitemoe
), ),
MODEL_TENSOR.ATTN_Q_NORM: ( MODEL_TENSOR.ATTN_Q_NORM: (


@ -1704,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
        }
    }

    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
    if (kv_self != nullptr) {
        LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
        kv_self->state_write(io);
    }

    return io.n_bytes();
}


@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
void llama_kv_cache_unified::set_full() {
    n = size;
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
// we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
// setting it to 0 is the simplest way to achieve that
// ref: https://github.com/ggml-org/llama.cpp/issues/13359
head = 0;
}

llama_sbatch llama_kv_cache_unified::sbatch_init(
@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
void llama_kv_cache_recurrent::set_full() {
    n = size;
    head = 0;
}

llama_sbatch llama_kv_cache_recurrent::sbatch_init(
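
A tiny sketch of the invariant the new comment describes (values illustrative):

#include <cassert>
#include <cstdint>

int main() {
    const uint32_t size = 8;
    const uint32_t n    = size;  // simulating a full KV cache, as set_full() does
    const uint32_t head = 0;     // the value set_full() now forces
    // A view of n cells starting at head must stay inside the buffer. With
    // head = 0 this holds trivially; a stale nonzero head could violate it
    // once n == size.
    assert(head + n <= size);
    return 0;
}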


@@ -171,11 +171,8 @@ public:
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

-   // Note: The value of head isn't only used to optimize searching
-   // for a free KV slot. llama_decode_impl also uses it, so it
-   // cannot be freely changed after a slot has been allocated.
-   uint32_t head = 0;
-   uint32_t size = 0;
+   uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
+   uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
@@ -343,11 +340,8 @@ public:
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

-   // Note: The value of head isn't only used to optimize searching
-   // for a free KV slot. llama_decode_impl also uses it, so it
-   // cannot be freely changed after a slot has been allocated.
-   uint32_t head = 0;
-   uint32_t size = 0;
+   uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
+   uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
View file
@@ -473,7 +473,7 @@ llama_model_loader::llama_model_loader(
    meta.reset(gguf_init_from_file(fname.c_str(), params));
    if (!meta) {
-       throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
+       throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str()));
    }

    get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
@@ -532,7 +532,7 @@ llama_model_loader::llama_model_loader(
        };
        gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
        if (!ctx_gguf) {
-           throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
+           throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split));
        }

        // check idx
@@ -827,13 +827,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
    mappings.reserve(files.size());
    mmaps_used.reserve(files.size());
    for (const auto & file : files) {
-       auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
-       if (!reg) {
-           throw std::runtime_error(format("%s: no CPU backend found", __func__));
+       bool is_numa = false;
+
+       auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+       if (dev) {
+           auto * reg = ggml_backend_dev_backend_reg(dev);
+           auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
+           if (is_numa_fn) {
+               is_numa = is_numa_fn();
+           }
        }

-       auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
-       std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
+       std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa);
        mmaps_used.emplace_back(mapping->size(), 0);
        if (mlock_mmaps) {
            std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
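The new code is an instance of ggml's optional-capability pattern: each lookup step may fail, and failure leaves a safe default instead of throwing. A standalone sketch of just that shape (find_cpu_device/find_proc_address are hypothetical stand-ins for the registry calls above):

#include <cstdio>

using query_fn_t = bool (*)();

// Hypothetical stand-ins for ggml_backend_dev_by_type() and
// ggml_backend_reg_get_proc_address(); either may legitimately return null.
static void * find_cpu_device() { return nullptr; }
static query_fn_t find_proc_address(void * /*dev*/, const char * /*name*/) { return nullptr; }

int main() {
    bool is_numa = false; // safe default if any step fails
    if (void * dev = find_cpu_device()) {
        if (query_fn_t fn = find_proc_address(dev, "ggml_backend_cpu_is_numa")) {
            is_numa = fn();
        }
    }
    std::printf("is_numa = %d\n", is_numa);
    return 0;
}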
View file
@@ -12322,6 +12322,9 @@ struct llm_build_granite : public llm_graph_context {
        // inp_pos - built only if rope enabled
        ggml_tensor * inp_pos = nullptr;
+       if (use_rope) {
+           inp_pos = build_inp_pos();
+       }

        auto * inp_attn = build_attn_inp_kv_unified();
@@ -12364,10 +12367,6 @@ struct llm_build_granite : public llm_graph_context {
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            if (use_rope) {
-               if (!inp_pos) {
-                   inp_pos = build_inp_pos();
-               }
                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

                Qcur = ggml_rope_ext(
                    ctx0, Qcur, inp_pos, rope_factors,
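The change hoists the position-input construction out of the per-layer branch so it is built exactly once up front. A minimal runnable sketch of the hoisting pattern (tensor/build_inp_pos are stand-ins, not the real graph API):

#include <cstdio>

struct tensor { int id; };

// stand-in for the graph builder used in the hunk above
static tensor * build_inp_pos() {
    static tensor t{0};
    return &t;
}

static void build_graph(bool use_rope, int n_layer) {
    tensor * inp_pos = nullptr;
    if (use_rope) {
        inp_pos = build_inp_pos(); // built once, before the layer loop
    }
    for (int il = 0; il < n_layer; ++il) {
        if (use_rope) {
            // every layer reuses the same input; no lazy per-layer construction
            std::printf("layer %d uses inp_pos %p\n", il, (void *) inp_pos);
        }
    }
}

int main() { build_graph(true, 2); return 0; }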
View file
@@ -14,6 +14,12 @@
#include <thread>
#include <unordered_map>

+// Quantization types. Changes to this struct must be replicated in quantize.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
static void zeros(std::ofstream & file, size_t n) {
    char zero = 0;
    for (size_t i = 0; i < n; ++i) {
@@ -48,12 +54,6 @@ struct quantize_state_impl {
    {}
};

-// changes to this struct must be replicated in quantize.cpp
-struct tensor_quantization {
-    std::string name;
-    ggml_type quant = GGML_TYPE_COUNT;
-};
-
static void llama_tensor_dequantize_impl(
    ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
    const size_t nelements, const int nthread
@@ -799,17 +799,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
            // unless the user specifies a type
            if (params->tensor_types) {
                const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
+               const std::string tensor_name(tensor->name);
                for (const auto & [tname, qtype] : tensor_types) {
-                   if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
+                   if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
                        if (qtype != new_type) {
-                           LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
+                           LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
+                           new_type = qtype;
+                           break; // if two or more types are specified for the tensor, first match wins
                        }
-                       new_type = qtype;
-                       break;
                    }
                }
            }
        }
        if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
            new_type = params->token_embedding_type;
        }
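The override loop above is first-match-wins over user-supplied regex patterns. A standalone sketch of the same control flow, with ggml types stood in by plain strings:

#include <cstdio>
#include <regex>
#include <string>
#include <utility>
#include <vector>

int main() {
    // patterns as they would arrive from --tensor-type, in order
    const std::vector<std::pair<std::string, std::string>> tensor_types = {
        { "attn_v",   "q6_K" },
        { "ffn_down", "q5_K" },
    };
    const std::string tensor_name = "blk.0.attn_v.weight";

    std::string new_type = "q4_K"; // type chosen by the default heuristics
    for (const auto & [tname, qtype] : tensor_types) {
        if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
            if (qtype != new_type) {
                new_type = qtype; // first matching pattern wins
                break;
            }
        }
    }
    std::printf("%s -> %s\n", tensor_name.c_str(), new_type.c_str());
    return 0;
}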
View file
@@ -0,0 +1,288 @@
// Tests common_regex (esp. its partial final matches support).
#include "common.h"
#include "regex-partial.h"
#include <sstream>
#include <iostream>
#include <optional>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << " Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
struct test_case {
std::string pattern;
struct input_output {
std::string input;
common_regex_match output;
};
std::vector<input_output> inputs_outputs;
};
static std::string common_regex_match_type_name(common_regex_match_type type) {
switch (type) {
case COMMON_REGEX_MATCH_TYPE_NONE:
return "COMMON_REGEX_MATCH_TYPE_NONE";
case COMMON_REGEX_MATCH_TYPE_PARTIAL:
return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
case COMMON_REGEX_MATCH_TYPE_FULL:
return "COMMON_REGEX_MATCH_TYPE_FULL";
}
return "?";
}
static void test_regex() {
printf("[%s]\n", __func__);
auto test = [](const test_case & test_case) {
common_regex cr(test_case.pattern);
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
for (const auto & input_output : test_case.inputs_outputs) {
std::cout << " Input: " << input_output.input << '\n';
auto m = cr.search(input_output.input, 0);
if (m != input_output.output) {
auto match_to_str = [&](const std::optional<common_regex_match> & m) {
std::ostringstream ss;
if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
ss << "<no match>";
} else {
GGML_ASSERT(!input_output.output.groups.empty());
std::vector<std::string> parts;
for (const auto & g : m->groups) {
parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
}
ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
}
return ss.str();
};
std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
std::cout << " Got: " << match_to_str(m) << '\n';
std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
throw std::runtime_error("Test failed");
}
}
};
test({
"a",
{
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
}
});
test({
"abcd",
{
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"d", {}},
{"bcd", {}},
{"cde", {}},
{"cd", {}},
{"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
{"abbie", {}},
{"", {}},
}
});
test({
".*?ab",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
}
});
test({
"a.*?b",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"d", {}},
{"b", {}},
}
});
test({
"ab(?:cd){2,4}ef",
{
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abcde", {}},
{"abcdef", {}},
{"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
{"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
{"abcdcdcdcdcdef", {}},
{"abcde", {}},
{"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
}
});
test({
"a(?:rte| pure )fact",
{
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"fact", {}},
{"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
{"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
{"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
{"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
{"" , {}},
{"pure", {}},
{"pure fact", {}},
}
});
test({
"abc",
{
{" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"b", {}},
{"c", {}},
{"", {}},
}
});
test({
"(?:abc)?\\s*def",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
{" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
}
});
test({
"a+b",
{
{"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
}
});
test({
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">", // match 5 (function name again)
{
{"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
{"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
{"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
{"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
{"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
{"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
{"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
{"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
{"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
{"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
}
});
}
static void test_regex_to_reversed_partial_regex() {
printf("[%s]\n", __func__);
assert_equals<std::string>(
"((?:(?:c)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("abc"));
assert_equals<std::string>(
"(a+)[\\s\\S]*",
regex_to_reversed_partial_regex("a+"));
assert_equals<std::string>(
"(a*)[\\s\\S]*",
regex_to_reversed_partial_regex("a*"));
assert_equals<std::string>(
"(a?)[\\s\\S]*",
regex_to_reversed_partial_regex("a?"));
assert_equals<std::string>(
"([a-z])[\\s\\S]*",
regex_to_reversed_partial_regex("[a-z]"));
assert_equals<std::string>(
"((?:\\w+)?[a-z])[\\s\\S]*",
regex_to_reversed_partial_regex("[a-z]\\w+"));
assert_equals<std::string>(
"((?:a|b))[\\s\\S]*",
regex_to_reversed_partial_regex("(?:a|b)"));
assert_equals<std::string>(
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("abcd"));
assert_equals<std::string>(
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
regex_to_reversed_partial_regex("a*b"));
assert_equals<std::string>(
"((?:(?:b)?a)?.*)[\\s\\S]*",
regex_to_reversed_partial_regex(".*?ab"));
assert_equals<std::string>(
"((?:(?:b)?.*)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a.*?b"));
assert_equals<std::string>(
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a(bc)d"));
assert_equals<std::string>(
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a(bc|de)"));
assert_equals<std::string>(
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("ab{2,4}c"));
}
int main() {
test_regex_to_reversed_partial_regex();
test_regex();
std::cout << "All tests passed.\n";
}
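For context, a hedged usage sketch built only on the common_regex API exercised by these tests: a streaming consumer holds back output while the tail of the buffer is a PARTIAL match, and acts once the match becomes FULL.

#include "regex-partial.h"

#include <iostream>
#include <string>

int main() {
    common_regex cr("<tool_call>");
    std::string buf;
    const char * chunks[] = { "Hello <tool", "_call> {...}" };
    for (const char * chunk : chunks) {
        buf += chunk;
        auto m = cr.search(buf, 0);
        if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
            // the tail of buf may still grow into a match: don't flush it yet
        } else if (m.type == COMMON_REGEX_MATCH_TYPE_FULL) {
            std::cout << "marker at [" << m.groups[0].begin << ", " << m.groups[0].end << ")\n";
        }
    }
    return 0;
}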
View file
@@ -58,6 +58,12 @@ static const std::vector<quant_option> QUANT_OPTIONS = {
    { "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
};

+// Quantization types. Changes to this struct must be replicated in llama-quantize.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE      = "quantize.imatrix.file";
static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET   = "quantize.imatrix.dataset";
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
@@ -245,56 +251,10 @@ static ggml_type parse_ggml_type(const char * arg) {
            return type;
        }
    }
-   fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
+   fprintf(stderr, "\n%s: invalid ggml_type '%s'\n\n", __func__, arg);
    return GGML_TYPE_COUNT;
}

-// Allowed tensors for arbitrary quantization with --tensor-type option
-static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
-    "attn_k",
-    "attn_kv_a_mqa",
-    "attn_kv_b",
-    "attn_o",
-    "attn_output",
-    "attn_q",
-    "attn_q_a",
-    "attn_q_b",
-    "attn_qkv",
-    "attn_v",
-    "channel_mix_key",
-    "channel_mix_receptance",
-    "channel_mix_value",
-    "cls",
-    "cls.output",
-    "cross_attn_k",
-    "cross_attn_o",
-    "cross_attn_q",
-    "cross_attn_v",
-    "ffn_act",
-    "ffn_down",
-    "ffn_down_exps",
-    "ffn_down_shexp",
-    "ffn_gate",
-    "ffn_gate_exps",
-    "ffn_gate_shexp",
-    "ffn_up",
-    "ffn_up_exps",
-    "ffn_up_shexp",
-    "ssm_in",
-    "ssm_out",
-    "time_mix_gate",
-    "time_mix_key",
-    "time_mix_output",
-    "time_mix_receptance",
-    "time_mix_value",
-};
-
-// changes to this struct must be replicated in llama-quant.cpp
-struct tensor_quantization {
-    std::string name;
-    ggml_type quant = GGML_TYPE_COUNT;
-};
-
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
    const char * sep = strchr(data, '=');
    if (sep == nullptr) {
@@ -307,7 +267,6 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
        printf("\n%s: missing tensor name\n\n", __func__);
        return false;
    }
-
    if (const size_t qt_len = strlen(sep); qt_len == 1) {
        printf("\n%s: missing quantization type\n\n", __func__);
        return false;
@@ -316,37 +275,15 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
    std::string tn(data, tn_len);
    std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
    sep++;
-   const std::string qt(sep);
-
-   bool found = false;
-   for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
-       std::string tensor;
-       tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
-       // handle special case of cls.output
-       std::string cls_output = "cls.output";
-       if (tn.find(cls_output) != std::string::npos) {
-           tensor = "cls.output";
-       }
-       // check if an allowed tensor exists and it's at the end of the kv string
-       if (tensor == allowed) {
-           found = true;
-           break;
-       }
-   }
-   if (!found) {
-       printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
-       return false;
-   }
-
-   if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
-       printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
-       return false;
-   }
-
    tensor_quantization tqz;
    tqz.name = tn;
-   tqz.quant = parse_ggml_type(qt.c_str());
+   tqz.quant = parse_ggml_type(sep);
    tensor_type.emplace_back(std::move(tqz));
+   if (tqz.quant == GGML_TYPE_COUNT) {
+       printf("\n%s: invalid quantization type '%s'\n\n", __func__, sep);
+       return false;
+   }
    return true;
}
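The CLI surface is unchanged; only validation moved. An illustrative invocation: llama-quantize --tensor-type ffn_down=q8_0 model-f16.gguf model-quant.gguf Q4_K_M. With the allow-list gone, the left-hand side is treated as a regex matched against full tensor names in llama-quant.cpp, and an unknown type on the right-hand side is still rejected, just later, via parse_ggml_type.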
Binary file not shown.

View file
@@ -1429,7 +1429,7 @@ struct server_slot {
                pos = text.find(word, from_pos);
            } else {
                // otherwise, partial stop
-               pos = find_partial_stop_string(word, text);
+               pos = string_find_partial_stop(text, word);
            }

            if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
@@ -2951,7 +2951,8 @@ struct server_context {
            llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
            llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);

-           if (slot.params.cache_prompt) {
+           // add generated tokens to cache
+           {
                llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
                for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                    new_tokens[i - n_discard] = new_tokens[i];
@@ -2996,10 +2997,7 @@ struct server_context {
            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);

            slot.n_past += 1;
-
-           if (slot.params.cache_prompt) {
-               slot.cache_tokens.push_back(slot.sampled);
-           }
+           slot.cache_tokens.push_back(slot.sampled);

            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
                slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3171,6 +3169,11 @@ struct server_context {
                            SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
                        }
+                   } else {
+                       // if we don't cache the prompt, we have to remove the entire KV cache
+                       llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
+                       slot.n_past = 0;
+                       slot.cache_tokens.clear();
                    }
                }
@@ -3204,7 +3207,7 @@ struct server_context {
                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);

                    // remove the non-common part from the cache
-                   slot.cache_tokens.resize(slot.n_past);
+                   slot.cache_tokens.keep_first(slot.n_past);

                    // check if we should process the image
                    if (slot.n_past < slot.n_prompt_tokens
@@ -3221,7 +3224,8 @@ struct server_context {
                            continue;
                        }

-                       if (slot.params.cache_prompt) {
+                       // add the image chunk to cache
+                       {
                            const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
                            slot.cache_tokens.push_back(chunk.get()); // copy
                        }
@@ -3242,9 +3246,7 @@ struct server_context {
                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;

                        common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-                       if (slot.params.cache_prompt) {
-                           slot.cache_tokens.push_back(cur_tok);
-                       }
+                       slot.cache_tokens.push_back(cur_tok);

                        slot.n_prompt_tokens_processed++;
                        slot.n_past++;
@@ -3705,6 +3707,9 @@ int main(int argc, char ** argv) {
            if (req.path == "/" || tmp.back() == "html") {
                res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
                res.status = 503;
+           } else if (req.path == "/models" || req.path == "/v1/models") {
+               // allow the models endpoint to be accessed during loading
+               return true;
            } else {
                res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
            }
@@ -4363,7 +4368,13 @@ int main(int argc, char ** argv) {
        res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
    };

-   const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+   const auto handle_models = [&params, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) {
+       server_state current_state = state.load();
+       json model_meta = nullptr;
+       if (current_state == SERVER_STATE_READY) {
+           model_meta = ctx_server.model_meta();
+       }
+
        json models = {
            {"object", "list"},
            {"data", {
@@ -4372,7 +4383,7 @@ int main(int argc, char ** argv) {
                    {"object",   "model"},
                    {"created",  std::time(0)},
                    {"owned_by", "llamacpp"},
-                   {"meta",     ctx_server.model_meta()}
+                   {"meta",     model_meta},
                },
            }}
        };
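The payload now degrades gracefully while the model is loading: "meta" is null until the server reaches SERVER_STATE_READY. A standalone illustration of the response shape (plain nlohmann::json here; the server uses its vendored copy):

#include <nlohmann/json.hpp>

#include <ctime>
#include <iostream>

using json = nlohmann::json;

static json make_models_payload(bool ready, const json & model_meta) {
    return json {
        {"object", "list"},
        {"data", json::array({
            json {
                {"id",       "model-name"}, // placeholder, not a real alias
                {"object",   "model"},
                {"created",  std::time(0)},
                {"owned_by", "llamacpp"},
                {"meta",     ready ? model_meta : json(nullptr)},
            },
        })},
    };
}

int main() {
    std::cout << make_models_payload(false, json::object()).dump(2) << "\n";
    return 0;
}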
View file
@@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
    assert res_cache.body["content"] == res_no_cache.body["content"]

+def test_nocache_long_input_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is"*32,
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
def test_completion_with_tokens_input():
    global server
    server.temperature = 0.0
View file
@@ -0,0 +1,49 @@
#!/usr/bin/env python
import pytest
# ensure grandparent path is in sys.path
from pathlib import Path
import sys
from unit.test_tool_call import TEST_TOOL
path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))
import datetime
from utils import *
server: ServerProcess
TIMEOUT_SERVER_START = 15*60
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2"
server.server_port = 8081
server.n_slots = 1
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,format", [
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
])
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
global server
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/apply-template", data={
"messages": [
{"role": "user", "content": "What is today?"},
],
"tools": tools,
})
assert res.status_code == 200
prompt = res.body["prompt"]
today_str = datetime.date.today().strftime(format)
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
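The two strftime formats parametrized above are what the respective templates are expected to render for the current date; a quick standalone illustration of the difference:

#include <ctime>
#include <iomanip>
#include <iostream>

int main() {
    std::time_t t = std::time(nullptr);
    std::tm tm = *std::localtime(&t);
    std::cout << std::put_time(&tm, "%d %b %Y") << "\n"; // Llama-3.3 style, e.g. "16 May 2025"
    std::cout << std::put_time(&tm, "%b %d %Y") << "\n"; // firefunction-v2 style, e.g. "May 16 2025"
    return 0;
}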
View file
@@ -109,7 +109,7 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
    global server
-   n_predict = 512
+   n_predict = 1024
    # server = ServerPreset.stories15m_moe()
    server.jinja = True
    server.n_predict = n_predict
View file
@@ -643,6 +643,18 @@ static json oaicompat_completion_params_parse(
        throw std::runtime_error("Expected 'messages' to be an array");
    }
    for (auto & msg : messages) {
+       std::string role = json_value(msg, "role", std::string());
+       if (role != "assistant" && !msg.contains("content")) {
+           throw std::runtime_error("All non-assistant messages must contain 'content'");
+       }
+       if (role == "assistant") {
+           if (!msg.contains("content") && !msg.contains("tool_calls")) {
+               throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
+           }
+           if (!msg.contains("content")) {
+               continue; // avoid errors with no content
+           }
+       }
        json & content = msg.at("content");
        if (content.is_string() || content.is_null()) {
            continue;
@@ -1153,7 +1165,7 @@ public:
        tokens.clear();
    }

-   void resize(size_t n) {
+   void keep_first(size_t n) {
        GGML_ASSERT(n <= tokens.size());
        if (has_mtmd) {
            // we throw an error if we try to remove a token in the middle of an image
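The rename from resize() to keep_first() matches what the call sites actually do: truncate to a common prefix. A minimal sketch of the semantics for the plain-token case (the real server_tokens additionally guards mtmd image chunks, as the assert above shows):

#include <cassert>
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// keep_first(n): drop everything after the first n tokens
static void keep_first(std::vector<llama_token> & tokens, size_t n) {
    assert(n <= tokens.size()); // mirrors the GGML_ASSERT above
    tokens.resize(n);
}

int main() {
    std::vector<llama_token> cache = {1, 2, 3, 4, 5};
    keep_first(cache, 3); // keep only the prefix shared with the new prompt
    assert(cache.size() == 3);
    return 0;
}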
View file
@@ -18,6 +18,7 @@
        "dexie": "^4.0.11",
        "highlight.js": "^11.10.0",
        "katex": "^0.16.15",
+       "pdfjs-dist": "^5.2.133",
        "postcss": "^8.4.49",
        "react": "^18.3.1",
        "react-dom": "^18.3.1",
@@ -44,6 +45,7 @@
        "eslint": "^9.17.0",
        "eslint-plugin-react-hooks": "^5.0.0",
        "eslint-plugin-react-refresh": "^0.4.16",
+       "fflate": "^0.8.2",
        "globals": "^15.14.0",
        "prettier": "^3.4.2",
        "sass-embedded": "^1.83.4",
@@ -987,7 +989,7 @@
        "version": "0.3.8",
        "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz",
        "integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==",
-       "dev": true,
+       "devOptional": true,
        "license": "MIT",
        "dependencies": {
            "@jridgewell/set-array": "^1.2.1",
@@ -1002,7 +1004,7 @@
        "version": "3.1.2",
        "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz",
        "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==",
-       "dev": true,
+       "devOptional": true,
        "license": "MIT",
        "engines": {
            "node": ">=6.0.0"
@@ -1012,30 +1014,224 @@
        "version": "1.2.1",
        "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz",
        "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==",
-       "dev": true,
+       "devOptional": true,
        "license": "MIT",
        "engines": {
            "node": ">=6.0.0"
        }
    },
"node_modules/@jridgewell/source-map": {
"version": "0.3.6",
"resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz",
"integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==",
"license": "MIT",
"optional": true,
"peer": true,
"dependencies": {
"@jridgewell/gen-mapping": "^0.3.5",
"@jridgewell/trace-mapping": "^0.3.25"
}
},
"node_modules/@jridgewell/sourcemap-codec": { "node_modules/@jridgewell/sourcemap-codec": {
"version": "1.5.0", "version": "1.5.0",
"resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz",
"integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==",
"dev": true, "devOptional": true,
"license": "MIT" "license": "MIT"
}, },
"node_modules/@jridgewell/trace-mapping": { "node_modules/@jridgewell/trace-mapping": {
"version": "0.3.25", "version": "0.3.25",
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz",
"integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==",
"dev": true, "devOptional": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"@jridgewell/resolve-uri": "^3.1.0", "@jridgewell/resolve-uri": "^3.1.0",
"@jridgewell/sourcemap-codec": "^1.4.14" "@jridgewell/sourcemap-codec": "^1.4.14"
} }
}, },
"node_modules/@napi-rs/canvas": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.70.tgz",
"integrity": "sha512-nD6NGa4JbNYSZYsTnLGrqe9Kn/lCkA4ybXt8sx5ojDqZjr2i0TWAHxx/vhgfjX+i3hCdKWufxYwi7CfXqtITSA==",
"license": "MIT",
"optional": true,
"engines": {
"node": ">= 10"
},
"optionalDependencies": {
"@napi-rs/canvas-android-arm64": "0.1.70",
"@napi-rs/canvas-darwin-arm64": "0.1.70",
"@napi-rs/canvas-darwin-x64": "0.1.70",
"@napi-rs/canvas-linux-arm-gnueabihf": "0.1.70",
"@napi-rs/canvas-linux-arm64-gnu": "0.1.70",
"@napi-rs/canvas-linux-arm64-musl": "0.1.70",
"@napi-rs/canvas-linux-riscv64-gnu": "0.1.70",
"@napi-rs/canvas-linux-x64-gnu": "0.1.70",
"@napi-rs/canvas-linux-x64-musl": "0.1.70",
"@napi-rs/canvas-win32-x64-msvc": "0.1.70"
}
},
"node_modules/@napi-rs/canvas-android-arm64": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.70.tgz",
"integrity": "sha512-I/YOuQ0wbkVYxVaYtCgN42WKTYxNqFA0gTcTrHIGG1jfpDSyZWII/uHcjOo4nzd19io6Y4+/BqP8E5hJgf9OmQ==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-darwin-arm64": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.70.tgz",
"integrity": "sha512-4pPGyXetHIHkw2TOJHujt3mkCP8LdDu8+CT15ld9Id39c752RcI0amDHSuMLMQfAjvusA9B5kKxazwjMGjEJpQ==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-darwin-x64": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.70.tgz",
"integrity": "sha512-+2N6Os9LbkmDMHL+raknrUcLQhsXzc5CSXRbXws9C3pv/mjHRVszQ9dhFUUe9FjfPhCJznO6USVdwOtu7pOrzQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-arm-gnueabihf": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.70.tgz",
"integrity": "sha512-QjscX9OaKq/990sVhSMj581xuqLgiaPVMjjYvWaCmAJRkNQ004QfoSMEm3FoTqM4DRoquP8jvuEXScVJsc1rqQ==",
"cpu": [
"arm"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-arm64-gnu": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.70.tgz",
"integrity": "sha512-LNakMOwwqwiHIwMpnMAbFRczQMQ7TkkMyATqFCOtUJNlE6LPP/QiUj/mlFrNbUn/hctqShJ60gWEb52ZTALbVw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-arm64-musl": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.70.tgz",
"integrity": "sha512-wBTOllEYNfJCHOdZj9v8gLzZ4oY3oyPX8MSRvaxPm/s7RfEXxCyZ8OhJ5xAyicsDdbE5YBZqdmaaeP5+xKxvtg==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-riscv64-gnu": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.70.tgz",
"integrity": "sha512-GVUUPC8TuuFqHip0rxHkUqArQnlzmlXmTEBuXAWdgCv85zTCFH8nOHk/YCF5yo0Z2eOm8nOi90aWs0leJ4OE5Q==",
"cpu": [
"riscv64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-x64-gnu": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.70.tgz",
"integrity": "sha512-/kvUa2lZRwGNyfznSn5t1ShWJnr/m5acSlhTV3eXECafObjl0VBuA1HJw0QrilLpb4Fe0VLywkpD1NsMoVDROQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-x64-musl": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.70.tgz",
"integrity": "sha512-aqlv8MLpycoMKRmds7JWCfVwNf1fiZxaU7JwJs9/ExjTD8lX2KjsO7CTeAj5Cl4aEuzxUWbJPUUE2Qu9cZ1vfg==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-win32-x64-msvc": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.70.tgz",
"integrity": "sha512-Q9QU3WIpwBTVHk4cPfBjGHGU4U0llQYRXgJtFtYqqGNEOKVN4OT6PQ+ve63xwIPODMpZ0HHyj/KLGc9CWc3EtQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@nodelib/fs.scandir": { "node_modules/@nodelib/fs.scandir": {
"version": "2.1.5", "version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
@ -2001,7 +2197,7 @@
"version": "8.14.0", "version": "8.14.0",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz", "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz",
"integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==", "integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==",
"dev": true, "devOptional": true,
"license": "MIT", "license": "MIT",
"bin": { "bin": {
"acorn": "bin/acorn" "acorn": "bin/acorn"
@ -2185,6 +2381,14 @@
"devOptional": true, "devOptional": true,
"license": "MIT/X11" "license": "MIT/X11"
}, },
"node_modules/buffer-from": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz",
"integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==",
"license": "MIT",
"optional": true,
"peer": true
},
"node_modules/callsites": { "node_modules/callsites": {
"version": "3.1.0", "version": "3.1.0",
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
@ -2802,6 +3006,13 @@
"reusify": "^1.0.4" "reusify": "^1.0.4"
} }
}, },
"node_modules/fflate": {
"version": "0.8.2",
"resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz",
"integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==",
"dev": true,
"license": "MIT"
},
"node_modules/file-entry-cache": { "node_modules/file-entry-cache": {
"version": "8.0.0", "version": "8.0.0",
"resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz", "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz",
@ -4835,6 +5046,18 @@
"node": ">=8" "node": ">=8"
} }
}, },
"node_modules/pdfjs-dist": {
"version": "5.2.133",
"resolved": "https://registry.npmjs.org/pdfjs-dist/-/pdfjs-dist-5.2.133.tgz",
"integrity": "sha512-abE6ZWDxztt+gGFzfm4bX2ggfxUk9wsDEoFzIJm9LozaY3JdXR7jyLK4Bjs+XLXplCduuWS1wGhPC4tgTn/kzg==",
"license": "Apache-2.0",
"engines": {
"node": ">=20.16.0 || >=22.3.0"
},
"optionalDependencies": {
"@napi-rs/canvas": "^0.1.67"
}
},
"node_modules/picocolors": { "node_modules/picocolors": {
"version": "1.1.1", "version": "1.1.1",
"resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz",
@ -5745,6 +5968,17 @@
"node": ">=8" "node": ">=8"
} }
}, },
"node_modules/source-map": {
"version": "0.6.1",
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
"integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==",
"license": "BSD-3-Clause",
"optional": true,
"peer": true,
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/source-map-js": { "node_modules/source-map-js": {
"version": "1.2.1", "version": "1.2.1",
"resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",
@ -5754,6 +5988,18 @@
"node": ">=0.10.0" "node": ">=0.10.0"
} }
}, },
"node_modules/source-map-support": {
"version": "0.5.21",
"resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz",
"integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==",
"license": "MIT",
"optional": true,
"peer": true,
"dependencies": {
"buffer-from": "^1.0.0",
"source-map": "^0.6.0"
}
},
"node_modules/space-separated-tokens": { "node_modules/space-separated-tokens": {
"version": "2.0.2", "version": "2.0.2",
"resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz", "resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz",
@ -5851,6 +6097,34 @@
"node": ">=6" "node": ">=6"
} }
}, },
"node_modules/terser": {
"version": "5.39.1",
"resolved": "https://registry.npmjs.org/terser/-/terser-5.39.1.tgz",
"integrity": "sha512-Mm6+uad0ZuDtcV8/4uOZQDQ8RuiC5Pu+iZRedJtF7yA/27sPL7d++In/AJKpWZlU3SYMPPkVfwetn6sgZ66pUA==",
"license": "BSD-2-Clause",
"optional": true,
"peer": true,
"dependencies": {
"@jridgewell/source-map": "^0.3.3",
"acorn": "^8.8.2",
"commander": "^2.20.0",
"source-map-support": "~0.5.20"
},
"bin": {
"terser": "bin/terser"
},
"engines": {
"node": ">=10"
}
},
"node_modules/terser/node_modules/commander": {
"version": "2.20.3",
"resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz",
"integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==",
"license": "MIT",
"optional": true,
"peer": true
},
"node_modules/textlinestream": { "node_modules/textlinestream": {
"version": "1.1.1", "version": "1.1.1",
"resolved": "https://registry.npmjs.org/textlinestream/-/textlinestream-1.1.1.tgz", "resolved": "https://registry.npmjs.org/textlinestream/-/textlinestream-1.1.1.tgz",
View file
@@ -5,7 +5,7 @@
    "type": "module",
    "scripts": {
        "dev": "vite",
-       "build": "tsc -b && vite build",
+       "build": "npm run format && tsc -b && vite build",
        "format": "eslint . && prettier --write .",
        "lint": "eslint .",
        "preview": "vite preview"
@@ -21,6 +21,7 @@
        "dexie": "^4.0.11",
        "highlight.js": "^11.10.0",
        "katex": "^0.16.15",
+       "pdfjs-dist": "^5.2.133",
        "postcss": "^8.4.49",
        "react": "^18.3.1",
        "react-dom": "^18.3.1",
@@ -47,6 +48,7 @@
        "eslint": "^9.17.0",
        "eslint-plugin-react-hooks": "^5.0.0",
        "eslint-plugin-react-refresh": "^0.4.16",
+       "fflate": "^0.8.2",
        "globals": "^15.14.0",
        "prettier": "^3.4.2",
        "sass-embedded": "^1.83.4",
View file
@@ -16,6 +16,8 @@ export const CONFIG_DEFAULT = {
    showTokensPerSecond: false,
    showThoughtInProgress: false,
    excludeThoughtOnReq: true,
+   pasteLongTextToFileLen: 2500,
+   pdfAsImage: false,
    // make sure these default values are in sync with `common.h`
    samplers: 'edkypmxt',
    temperature: 0.8,
@@ -43,6 +45,8 @@ export const CONFIG_DEFAULT = {
export const CONFIG_INFO: Record<string, string> = {
    apiKey: 'Set the API Key if you are using --api-key option for the server.',
    systemMessage: 'The starting message that defines how model should behave.',
+   pasteLongTextToFileLen:
+       'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.',
    samplers:
        'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature',
    temperature:
View file
@@ -1,4 +1,4 @@
-import { useEffect, useMemo, useRef, useState } from 'react';
+import { ClipboardEvent, useEffect, useMemo, useRef, useState } from 'react';
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
import ChatMessage from './ChatMessage';
import { CanvasType, Message, PendingMessage } from '../utils/types';
@@ -306,6 +306,7 @@ function ChatInput({
    onStop: () => void;
    isGenerating: boolean;
}) {
+   const { config } = useAppContext();
    const [isDrag, setIsDrag] = useState(false);

    return (
@@ -328,6 +329,38 @@ function ChatInput({
        {({ getRootProps, getInputProps }) => (
            <div
                className="flex flex-col rounded-xl border-1 border-base-content/30 p-3 w-full"
+               // when a file is pasted to the input, we handle it here
+               // if a text is pasted, and if it is long text, we will convert it to a file
+               onPasteCapture={(e: ClipboardEvent<HTMLInputElement>) => {
+                   const text = e.clipboardData.getData('text/plain');
+                   if (
+                       text.length > 0 &&
+                       config.pasteLongTextToFileLen > 0 &&
+                       text.length > config.pasteLongTextToFileLen
+                   ) {
+                       // if the text is too long, we will convert it to a file
+                       extraContext.addItems([
+                           {
+                               type: 'context',
+                               name: 'Pasted Content',
+                               content: text,
+                           },
+                       ]);
+                       e.preventDefault();
+                       return;
+                   }
+
+                   // if a file is pasted, we will handle it here
+                   const files = Array.from(e.clipboardData.items)
+                       .filter((item) => item.kind === 'file')
+                       .map((item) => item.getAsFile())
+                       .filter((file) => file !== null);
+
+                   if (files.length > 0) {
+                       e.preventDefault();
+                       extraContext.onFileAdded(files);
+                   }
+               }}
                {...getRootProps()}
            >
                {!isGenerating && (
View file
@@ -100,6 +100,16 @@ const SETTING_SECTIONS: SettingSection[] = [
                key,
            }) as SettingFieldInput
        ),
+       {
+           type: SettingInputType.SHORT_INPUT,
+           label: 'Paste length to file',
+           key: 'pasteLongTextToFileLen',
+       },
+       {
+           type: SettingInputType.CHECKBOX,
+           label: 'Parse PDF as image instead of text',
+           key: 'pdfAsImage',
+       },
    ],
},
{
@@ -452,10 +462,10 @@ function SettingsModalLongInput({
    label?: string;
}) {
    return (
-       <label className="form-control mb-2">
-           <div className="label inline">{label || configKey}</div>
+       <label className="form-control">
+           <div className="label inline text-sm">{label || configKey}</div>
            <textarea
-               className="textarea textarea-bordered h-24"
+               className="textarea textarea-bordered h-24 mb-2"
                placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
                value={value}
                onChange={(e) => onChange(e.target.value)}
@@ -482,9 +492,7 @@ function SettingsModalShortInput({
    <>
        {/* on mobile, we simply show the help message here */}
        {helpMsg && (
-           <div className="block md:hidden mb-1">
-               <b>{label || configKey}</b>
-               <br />
+           <div className="block mb-1 opacity-75">
                <p className="text-xs">{helpMsg}</p>
            </div>
        )}
@@ -493,11 +501,6 @@ function SettingsModalShortInput({
            <div tabIndex={0} role="button" className="font-bold hidden md:block">
                {label || configKey}
            </div>
-           {helpMsg && (
-               <div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
-                   {helpMsg}
-               </div>
-           )}
        </div>
        <input
            type="text"
View file
@@ -2,6 +2,17 @@ import { useState } from 'react';
import { MessageExtra } from '../utils/types';
import toast from 'react-hot-toast';
import { useAppContext } from '../utils/app.context';
+import * as pdfjs from 'pdfjs-dist';
+import pdfjsWorkerSrc from 'pdfjs-dist/build/pdf.worker.min.mjs?url';
+import { TextContent, TextItem } from 'pdfjs-dist/types/src/display/api';
+
+pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
+
+// This file handles uploading extra context items (a.k.a files)
+// It allows processing these kinds of files:
+// - image files (converted to base64)
+// - text files (including code files)
+// - pdf (converted to text)

// Interface describing the API returned by the hook
export interface ChatExtraContextApi {
@@ -13,7 +24,7 @@ export interface ChatExtraContextApi {
}

export function useChatExtraContext(): ChatExtraContextApi {
-   const { serverProps } = useAppContext();
+   const { serverProps, config } = useAppContext();
    const [items, setItems] = useState<MessageExtra[]>([]);

    const addItems = (newItems: MessageExtra[]) => {
@@ -28,6 +39,8 @@ export function useChatExtraContext(): ChatExtraContextApi {
        setItems([]);
    };

+   const isSupportVision = serverProps?.modalities?.vision;
+
    const onFileAdded = (files: File[]) => {
        for (const file of files) {
            const mimeType = file.type;
@@ -38,7 +51,7 @@ export function useChatExtraContext(): ChatExtraContextApi {
            }

            if (mimeType.startsWith('image/')) {
-               if (!serverProps?.modalities?.vision) {
+               if (!isSupportVision) {
                    toast.error('Multimodal is not supported by this server or model.');
                    break;
                }
@@ -69,7 +82,43 @@ export function useChatExtraContext(): ChatExtraContextApi {
                toast.error('Video and audio files are not supported yet.');
                break;
            } else if (mimeType.startsWith('application/pdf')) {
-               toast.error('PDF files are not supported yet.');
+               if (config.pdfAsImage && !isSupportVision) {
+                   toast(
+                       'Multimodal is not supported, PDF will be converted to text instead of image.'
+                   );
+                   break;
+               }
+
+               const promise =
+                   config.pdfAsImage && isSupportVision
+                       ? convertPDFToImage(file).then((base64Urls) => {
+                           addItems(
+                               base64Urls.map((base64Url) => ({
+                                   type: 'imageFile',
+                                   name: file.name,
+                                   base64Url,
+                               }))
+                           );
+                       })
+                       : convertPDFToText(file).then((content) => {
+                           if (isSupportVision) {
+                               toast.success(
+                                   'PDF file converted to text. You can also convert it to image, see in Settings.'
+                               );
+                           }
+                           addItems([
+                               {
+                                   type: 'textFile',
+                                   name: file.name,
+                                   content,
+                               },
+                           ]);
+                       });
+
+               promise.catch((error) => {
+                   console.error(error);
+                   toast.error('Failed to parse PDF file.');
+               });
                break;
            } else {
                // Because there can be many text file types (like code file), we will not check the mime type
@@ -105,11 +154,69 @@ export function useChatExtraContext(): ChatExtraContextApi {
    };
}
async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (event) => {
if (event.target?.result) {
resolve(event.target.result as ArrayBuffer);
} else {
reject(new Error('Failed to read file.'));
}
};
reader.readAsArrayBuffer(file);
});
}
async function convertPDFToText(file: File): Promise<string> {
const buffer = await getFileAsBuffer(file);
const pdf = await pdfjs.getDocument(buffer).promise;
const numPages = pdf.numPages;
const textContentPromises: Promise<TextContent>[] = [];
for (let i = 1; i <= numPages; i++) {
textContentPromises.push(
pdf.getPage(i).then((page) => page.getTextContent())
);
}
const textContents = await Promise.all(textContentPromises);
const textItems = textContents.flatMap((textContent: TextContent) =>
textContent.items.map((item) => (item as TextItem).str ?? '')
);
return textItems.join('\n');
}
// returns list of base64 images
async function convertPDFToImage(file: File): Promise<string[]> {
const buffer = await getFileAsBuffer(file);
const doc = await pdfjs.getDocument(buffer).promise;
const pages: Promise<string>[] = [];
for (let i = 1; i <= doc.numPages; i++) {
const page = await doc.getPage(i);
const viewport = page.getViewport({ scale: 1.5 });
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
canvas.width = viewport.width;
canvas.height = viewport.height;
if (!ctx) {
throw new Error('Failed to get 2D context from canvas');
}
const task = page.render({ canvasContext: ctx, viewport: viewport });
pages.push(
task.promise.then(() => {
return canvas.toDataURL();
})
);
}
return await Promise.all(pages);
}
// WARN: vibe code below
// This code is a heuristic to determine if a string is likely not binary.
// It is necessary because input file can have various mime types which we don't have time to investigate.
// For example, a python file can be text/plain, application/x-python, etc.
-export function isLikelyNotBinary(str: string): boolean {
+function isLikelyNotBinary(str: string): boolean {
    const options = {
        prefixLength: 1024 * 10, // Check the first 10KB of the string
        suspiciousCharThresholdRatio: 0.15, // Allow up to 15% suspicious chars
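The heuristic's core idea, rendered in C++ for brevity (the actual TypeScript version classifies more character ranges than this sketch does): scan only a bounded prefix and tolerate a small ratio of suspicious characters.

#include <algorithm>
#include <cassert>
#include <string>

static bool is_likely_not_binary(const std::string & s) {
    const size_t prefix_len     = 1024 * 10; // check the first 10KB only
    const double max_suspicious = 0.15;      // allow up to 15% suspicious chars
    const size_t n = std::min(prefix_len, s.size());
    if (n == 0) {
        return true; // empty input is trivially "text"
    }
    size_t suspicious = 0;
    for (size_t i = 0; i < n; ++i) {
        const unsigned char c = (unsigned char) s[i];
        // control characters other than tab/newline/CR are a strong binary signal
        if (c < 0x20 && c != '\t' && c != '\n' && c != '\r') {
            suspicious++;
        }
    }
    return (double) suspicious / (double) n <= max_suspicious;
}

int main() {
    assert(is_likely_not_binary("plain text\n"));
    assert(!is_likely_not_binary(std::string("\x00\x01\x02\x03", 4)));
    return 0;
}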
View file
@@ -3,11 +3,11 @@ import react from '@vitejs/plugin-react';
import { viteSingleFile } from 'vite-plugin-singlefile';
import path from 'node:path';
import fs from 'node:fs';
-import zlib from 'node:zlib';
+import * as fflate from 'fflate';

/* eslint-disable */

-const MAX_BUNDLE_SIZE = 1.5 * 1024 * 1024; // only increase when absolutely necessary
+const MAX_BUNDLE_SIZE = 2 * 1024 * 1024; // only increase when absolutely necessary

const GUIDE_FOR_FRONTEND = `
<!--
@@ -33,9 +33,10 @@ const BUILD_PLUGINS = [
    },
    writeBundle() {
        const outputIndexHtml = path.join(config.build.outDir, 'index.html');
-       const content =
+       let content =
            GUIDE_FOR_FRONTEND + '\n' + fs.readFileSync(outputIndexHtml, 'utf-8');
-       const compressed = zlib.gzipSync(Buffer.from(content, 'utf-8'), {
+       content = content.replace(/\r/g, ''); // remove windows-style line endings
+       const compressed = fflate.gzipSync(Buffer.from(content, 'utf-8'), {
            level: 9,
        });