mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # common/CMakeLists.txt # docs/backend/SYCL.md # ggml/CMakeLists.txt # ggml/src/ggml-sycl/CMakeLists.txt # ggml/src/ggml-sycl/binbcast.cpp # ggml/src/ggml-sycl/convert.cpp # ggml/src/ggml-sycl/dequantize.hpp # ggml/src/ggml-sycl/dmmv.cpp # ggml/src/ggml-sycl/gemm.hpp # ggml/src/ggml-sycl/ggml-sycl.cpp # ggml/src/ggml-sycl/mmvq.cpp # ggml/src/ggml-sycl/vecdotq.hpp # ggml/src/ggml-vulkan/CMakeLists.txt # ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt # ggml/src/gguf.cpp # scripts/compare-llama-bench.py # tests/CMakeLists.txt # tests/test-chat.cpp # tools/llama-bench/llama-bench.cpp # tools/server/README.md
This commit is contained in:
commit
e5d26a2356
47 changed files with 2671 additions and 504 deletions
240
common/chat.cpp
240
common/chat.cpp
|
@ -6,6 +6,15 @@
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
|
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
|
||||||
|
auto time = std::chrono::system_clock::to_time_t(now);
|
||||||
|
auto local_time = *std::localtime(&time);
|
||||||
|
std::ostringstream ss;
|
||||||
|
ss << std::put_time(&local_time, format.c_str());
|
||||||
|
auto res = ss.str();
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
typedef minja::chat_template common_chat_template;
|
typedef minja::chat_template common_chat_template;
|
||||||
|
|
||||||
struct common_chat_templates {
|
struct common_chat_templates {
|
||||||
|
@ -24,6 +33,7 @@ struct templates_params {
|
||||||
std::string grammar;
|
std::string grammar;
|
||||||
bool add_generation_prompt = true;
|
bool add_generation_prompt = true;
|
||||||
bool extract_reasoning = true;
|
bool extract_reasoning = true;
|
||||||
|
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||||
};
|
};
|
||||||
|
|
||||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
||||||
|
@ -939,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||||
auto builtin_tools = json::array();
|
auto builtin_tools = json::array();
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
if (!inputs.tools.is_null()) {
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||||
std::vector<std::string> tool_rules;
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
|
||||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||||
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
||||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||||
expect_tool_parameters(name, parameters, {"query"});
|
expect_tool_parameters(name, parameters, {"query"});
|
||||||
} else if (name == "python" || name == "code_interpreter") {
|
} else if (name == "python" || name == "code_interpreter") {
|
||||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||||
expect_tool_parameters(name, parameters, {"code"});
|
expect_tool_parameters(name, parameters, {"code"});
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> kvs;
|
||||||
|
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||||
|
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_rules.push_back(
|
||||||
|
builder.add_rule(
|
||||||
|
name + "-call",
|
||||||
|
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||||
|
builtin_tools.push_back(name);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
|
const auto & function = tool.at("function");
|
||||||
|
std::string name = function.at("name");
|
||||||
|
auto parameters = function.at("parameters");
|
||||||
|
builder.resolve_refs(parameters);
|
||||||
|
|
||||||
|
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||||
|
if (allow_python_tag_builtin_tools) {
|
||||||
|
handle_builtin_tool(name, parameters);
|
||||||
|
}
|
||||||
|
tool_rules.push_back(
|
||||||
|
builder.add_rule(
|
||||||
|
name + "-call",
|
||||||
|
"\"{\" space "
|
||||||
|
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
||||||
|
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
||||||
|
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
||||||
|
"\"}\" space"));
|
||||||
|
});
|
||||||
|
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
||||||
|
data.grammar_triggers.push_back({
|
||||||
|
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||||
|
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
||||||
|
});
|
||||||
|
if (!builtin_tools.empty()) {
|
||||||
|
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||||
|
data.preserved_tokens.push_back("<|python_tag|>");
|
||||||
}
|
}
|
||||||
|
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||||
std::vector<std::string> kvs;
|
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
data.additional_stops.push_back("<|eom_id|>");
|
||||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
|
||||||
}
|
|
||||||
|
|
||||||
tool_rules.push_back(
|
|
||||||
builder.add_rule(
|
|
||||||
name + "-call",
|
|
||||||
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
|
||||||
builtin_tools.push_back(name);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
|
|
||||||
foreach_function(inputs.tools, [&](const json & tool) {
|
|
||||||
const auto & function = tool.at("function");
|
|
||||||
std::string name = function.at("name");
|
|
||||||
auto parameters = function.at("parameters");
|
|
||||||
builder.resolve_refs(parameters);
|
|
||||||
|
|
||||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
|
||||||
if (allow_python_tag_builtin_tools) {
|
|
||||||
handle_builtin_tool(name, parameters);
|
|
||||||
}
|
|
||||||
tool_rules.push_back(
|
|
||||||
builder.add_rule(
|
|
||||||
name + "-call",
|
|
||||||
"\"{\" space "
|
|
||||||
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
|
||||||
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
|
||||||
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
|
||||||
"\"}\" space"));
|
|
||||||
});
|
});
|
||||||
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||||
data.grammar_triggers.push_back({
|
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||||
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
} else {
|
||||||
});
|
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||||
if (!builtin_tools.empty()) {
|
}
|
||||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
|
||||||
data.preserved_tokens.push_back("<|python_tag|>");
|
|
||||||
}
|
|
||||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
|
||||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
|
||||||
});
|
|
||||||
data.additional_stops.push_back("<|eom_id|>");
|
|
||||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||||
|
{"date_string", format_time(inputs.now, "%d %b %Y")},
|
||||||
{"tools_in_user_message", false},
|
{"tools_in_user_message", false},
|
||||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||||
});
|
});
|
||||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
|
||||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
|
||||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
||||||
|
@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
||||||
LOG_DBG("%s\n", __func__);
|
LOG_DBG("%s\n", __func__);
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||||
});
|
});
|
||||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||||
|
@ -1285,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
||||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
|
||||||
std::string python_code_argument_name;
|
|
||||||
auto has_raw_python = false;
|
|
||||||
|
|
||||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
if (!inputs.tools.is_null()) {
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
std::string python_code_argument_name;
|
||||||
std::vector<std::string> tool_rules;
|
auto has_raw_python = false;
|
||||||
foreach_function(inputs.tools, [&](const json & tool) {
|
|
||||||
const auto & function = tool.at("function");
|
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||||
const auto & parameters = function.at("parameters");
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::string name = function.at("name");
|
std::vector<std::string> tool_rules;
|
||||||
if (name == "python" || name == "ipython") {
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
if (!parameters.contains("type")) {
|
const auto & function = tool.at("function");
|
||||||
throw std::runtime_error("Missing type in python tool");
|
const auto & parameters = function.at("parameters");
|
||||||
}
|
std::string name = function.at("name");
|
||||||
has_raw_python = true;
|
if (name == "python" || name == "ipython") {
|
||||||
const auto & type = parameters.at("type");
|
if (!parameters.contains("type")) {
|
||||||
if (type == "object") {
|
throw std::runtime_error("Missing type in python tool");
|
||||||
auto properties = parameters.at("properties");
|
}
|
||||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
has_raw_python = true;
|
||||||
if (it.value().at("type") == "string") {
|
const auto & type = parameters.at("type");
|
||||||
if (!python_code_argument_name.empty()) {
|
if (type == "object") {
|
||||||
throw std::runtime_error("Multiple string arguments found in python tool");
|
auto properties = parameters.at("properties");
|
||||||
|
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||||
|
if (it.value().at("type") == "string") {
|
||||||
|
if (!python_code_argument_name.empty()) {
|
||||||
|
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||||
|
}
|
||||||
|
python_code_argument_name = it.key();
|
||||||
}
|
}
|
||||||
python_code_argument_name = it.key();
|
|
||||||
}
|
}
|
||||||
|
if (python_code_argument_name.empty()) {
|
||||||
|
throw std::runtime_error("No string argument found in python tool");
|
||||||
|
}
|
||||||
|
} else if (type != "string") {
|
||||||
|
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||||
}
|
}
|
||||||
if (python_code_argument_name.empty()) {
|
|
||||||
throw std::runtime_error("No string argument found in python tool");
|
|
||||||
}
|
|
||||||
} else if (type != "string") {
|
|
||||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
|
||||||
}
|
}
|
||||||
|
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||||
|
});
|
||||||
|
if (has_raw_python) {
|
||||||
|
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||||
|
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||||
|
data.preserved_tokens.push_back("<|python_tag|>");
|
||||||
}
|
}
|
||||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||||
|
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||||
|
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||||
});
|
});
|
||||||
if (has_raw_python) {
|
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
} else {
|
||||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||||
data.preserved_tokens.push_back("<|python_tag|>");
|
}
|
||||||
}
|
|
||||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
|
||||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
|
||||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
|
||||||
});
|
|
||||||
|
|
||||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
// TODO: if (has_raw_python)
|
// TODO: if (has_raw_python)
|
||||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
||||||
|
@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||||
params.extract_reasoning = inputs.extract_reasoning;
|
params.extract_reasoning = inputs.extract_reasoning;
|
||||||
params.tool_choice = inputs.tool_choice;
|
params.tool_choice = inputs.tool_choice;
|
||||||
params.grammar = inputs.grammar;
|
params.grammar = inputs.grammar;
|
||||||
|
params.now = inputs.now;
|
||||||
if (!inputs.json_schema.empty()) {
|
if (!inputs.json_schema.empty()) {
|
||||||
params.json_schema = json::parse(inputs.json_schema);
|
params.json_schema = json::parse(inputs.json_schema);
|
||||||
}
|
}
|
||||||
|
@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||||
return common_chat_params_init_firefunction_v2(tmpl, params);
|
return common_chat_params_init_firefunction_v2(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Plain handler (no tools)
|
|
||||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
|
||||||
return common_chat_params_init_without_tools(tmpl, params);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Functionary v3.1 (w/ tools)
|
// Functionary v3.1 (w/ tools)
|
||||||
if (src.find("<|start_header_id|>") != std::string::npos
|
if (src.find("<|start_header_id|>") != std::string::npos
|
||||||
&& src.find("<function=") != std::string::npos) {
|
&& src.find("<function=") != std::string::npos) {
|
||||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
|
||||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plain handler (no tools)
|
||||||
|
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||||
|
return common_chat_params_init_without_tools(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mistral Nemo (w/ tools)
|
// Mistral Nemo (w/ tools)
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
#include <chrono>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
|
||||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||||
bool parallel_tool_calls = false;
|
bool parallel_tool_calls = false;
|
||||||
bool extract_reasoning = true;
|
bool extract_reasoning = true;
|
||||||
|
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_chat_params {
|
struct common_chat_params {
|
||||||
|
|
|
@ -451,6 +451,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||||
s = std::move(builder);
|
s = std::move(builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
|
||||||
|
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||||
|
}
|
||||||
|
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||||
|
if (!str.empty() && !stop.empty()) {
|
||||||
|
const char text_last_char = str.back();
|
||||||
|
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||||
|
if (stop[char_index] == text_last_char) {
|
||||||
|
const auto current_partial = stop.substr(0, char_index + 1);
|
||||||
|
if (string_ends_with(str, current_partial)) {
|
||||||
|
return str.size() - char_index - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
std::string regex_escape(const std::string & s) {
|
std::string regex_escape(const std::string & s) {
|
||||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||||
return std::regex_replace(s, special_chars, "\\$0");
|
return std::regex_replace(s, special_chars, "\\$0");
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
@ -499,10 +500,9 @@ static bool string_starts_with(const std::string & str,
|
||||||
return str.rfind(prefix, 0) == 0;
|
return str.rfind(prefix, 0) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool string_ends_with(const std::string & str,
|
// While we wait for C++20's std::string::ends_with...
|
||||||
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
|
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||||
}
|
|
||||||
|
|
||||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||||
void string_process_escapes(std::string & input);
|
void string_process_escapes(std::string & input);
|
||||||
|
|
|
@ -13,10 +13,12 @@
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <ctime>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -393,8 +395,8 @@ class chat_template {
|
||||||
|
|
||||||
for (const auto & message_ : adjusted_messages) {
|
for (const auto & message_ : adjusted_messages) {
|
||||||
auto message = message_;
|
auto message = message_;
|
||||||
if (!message.contains("role") || !message.contains("content")) {
|
if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
|
||||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
|
||||||
}
|
}
|
||||||
std::string role = message.at("role");
|
std::string role = message.at("role");
|
||||||
|
|
||||||
|
@ -415,7 +417,6 @@ class chat_template {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (polyfill_tool_calls) {
|
if (polyfill_tool_calls) {
|
||||||
auto content = message.at("content");
|
|
||||||
auto tool_calls = json::array();
|
auto tool_calls = json::array();
|
||||||
for (const auto & tool_call : message.at("tool_calls")) {
|
for (const auto & tool_call : message.at("tool_calls")) {
|
||||||
if (tool_call.at("type") != "function") {
|
if (tool_call.at("type") != "function") {
|
||||||
|
@ -434,8 +435,11 @@ class chat_template {
|
||||||
auto obj = json {
|
auto obj = json {
|
||||||
{"tool_calls", tool_calls},
|
{"tool_calls", tool_calls},
|
||||||
};
|
};
|
||||||
if (!content.is_null() && !content.empty()) {
|
if (message.contains("content")) {
|
||||||
obj["content"] = content;
|
auto content = message.at("content");
|
||||||
|
if (!content.is_null() && !content.empty()) {
|
||||||
|
obj["content"] = content;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
message["content"] = obj.dump(2);
|
message["content"] = obj.dump(2);
|
||||||
message.erase("tool_calls");
|
message.erase("tool_calls");
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cctype>
|
#include <cctype>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
@ -233,7 +234,7 @@ public:
|
||||||
}
|
}
|
||||||
} else if (is_object()) {
|
} else if (is_object()) {
|
||||||
if (!index.is_hashable())
|
if (!index.is_hashable())
|
||||||
throw std::runtime_error("Unashable type: " + index.dump());
|
throw std::runtime_error("Unhashable type: " + index.dump());
|
||||||
auto it = object_->find(index.primitive_);
|
auto it = object_->find(index.primitive_);
|
||||||
if (it == object_->end())
|
if (it == object_->end())
|
||||||
throw std::runtime_error("Key not found: " + index.dump());
|
throw std::runtime_error("Key not found: " + index.dump());
|
||||||
|
@ -252,7 +253,7 @@ public:
|
||||||
auto index = key.get<int>();
|
auto index = key.get<int>();
|
||||||
return array_->at(index < 0 ? array_->size() + index : index);
|
return array_->at(index < 0 ? array_->size() + index : index);
|
||||||
} else if (object_) {
|
} else if (object_) {
|
||||||
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
|
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
|
||||||
auto it = object_->find(key.primitive_);
|
auto it = object_->find(key.primitive_);
|
||||||
if (it == object_->end()) return Value();
|
if (it == object_->end()) return Value();
|
||||||
return it->second;
|
return it->second;
|
||||||
|
@ -261,7 +262,7 @@ public:
|
||||||
}
|
}
|
||||||
void set(const Value& key, const Value& value) {
|
void set(const Value& key, const Value& value) {
|
||||||
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
|
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
|
||||||
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
|
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
|
||||||
(*object_)[key.primitive_] = value;
|
(*object_)[key.primitive_] = value;
|
||||||
}
|
}
|
||||||
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
|
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
|
||||||
|
@ -398,7 +399,7 @@ public:
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
} else if (object_) {
|
} else if (object_) {
|
||||||
if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
|
if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
|
||||||
return object_->find(value.primitive_) != object_->end();
|
return object_->find(value.primitive_) != object_->end();
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
|
throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
|
||||||
|
@ -416,7 +417,7 @@ public:
|
||||||
return const_cast<Value*>(this)->at(index);
|
return const_cast<Value*>(this)->at(index);
|
||||||
}
|
}
|
||||||
Value& at(const Value & index) {
|
Value& at(const Value & index) {
|
||||||
if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
|
if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
|
||||||
if (is_array()) return array_->at(index.get<int>());
|
if (is_array()) return array_->at(index.get<int>());
|
||||||
if (is_object()) return object_->at(index.primitive_);
|
if (is_object()) return object_->at(index.primitive_);
|
||||||
throw std::runtime_error("Value is not an array or object: " + dump());
|
throw std::runtime_error("Value is not an array or object: " + dump());
|
||||||
|
@ -676,8 +677,8 @@ public:
|
||||||
class VariableExpr : public Expression {
|
class VariableExpr : public Expression {
|
||||||
std::string name;
|
std::string name;
|
||||||
public:
|
public:
|
||||||
VariableExpr(const Location & location, const std::string& n)
|
VariableExpr(const Location & loc, const std::string& n)
|
||||||
: Expression(location), name(n) {}
|
: Expression(loc), name(n) {}
|
||||||
std::string get_name() const { return name; }
|
std::string get_name() const { return name; }
|
||||||
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
|
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
|
||||||
if (!context->contains(name)) {
|
if (!context->contains(name)) {
|
||||||
|
@ -1200,9 +1201,9 @@ public:
|
||||||
|
|
||||||
class SliceExpr : public Expression {
|
class SliceExpr : public Expression {
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<Expression> start, end;
|
std::shared_ptr<Expression> start, end, step;
|
||||||
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
|
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
|
||||||
: Expression(loc), start(std::move(s)), end(std::move(e)) {}
|
: Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
|
||||||
Value do_evaluate(const std::shared_ptr<Context> &) const override {
|
Value do_evaluate(const std::shared_ptr<Context> &) const override {
|
||||||
throw std::runtime_error("SliceExpr not implemented");
|
throw std::runtime_error("SliceExpr not implemented");
|
||||||
}
|
}
|
||||||
|
@ -1219,18 +1220,35 @@ public:
|
||||||
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
|
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
|
||||||
auto target_value = base->evaluate(context);
|
auto target_value = base->evaluate(context);
|
||||||
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
|
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
|
||||||
auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
|
auto len = target_value.size();
|
||||||
auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
|
auto wrap = [len](int64_t i) -> int64_t {
|
||||||
|
if (i < 0) {
|
||||||
|
return i + len;
|
||||||
|
}
|
||||||
|
return i;
|
||||||
|
};
|
||||||
|
int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
|
||||||
|
if (!step) {
|
||||||
|
throw std::runtime_error("slice step cannot be zero");
|
||||||
|
}
|
||||||
|
int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
|
||||||
|
int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
|
||||||
if (target_value.is_string()) {
|
if (target_value.is_string()) {
|
||||||
std::string s = target_value.get<std::string>();
|
std::string s = target_value.get<std::string>();
|
||||||
if (start < 0) start = s.size() + start;
|
|
||||||
if (end < 0) end = s.size() + end;
|
std::string result;
|
||||||
return s.substr(start, end - start);
|
if (start < end && step == 1) {
|
||||||
|
result = s.substr(start, end - start);
|
||||||
|
} else {
|
||||||
|
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
|
||||||
|
result += s[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
|
||||||
} else if (target_value.is_array()) {
|
} else if (target_value.is_array()) {
|
||||||
if (start < 0) start = target_value.size() + start;
|
|
||||||
if (end < 0) end = target_value.size() + end;
|
|
||||||
auto result = Value::array();
|
auto result = Value::array();
|
||||||
for (auto i = start; i < end; ++i) {
|
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
|
||||||
result.push_back(target_value.at(i));
|
result.push_back(target_value.at(i));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
@ -1305,6 +1323,8 @@ public:
|
||||||
if (name == "iterable") return l.is_iterable();
|
if (name == "iterable") return l.is_iterable();
|
||||||
if (name == "sequence") return l.is_array();
|
if (name == "sequence") return l.is_array();
|
||||||
if (name == "defined") return !l.is_null();
|
if (name == "defined") return !l.is_null();
|
||||||
|
if (name == "true") return l.to_bool();
|
||||||
|
if (name == "false") return !l.to_bool();
|
||||||
throw std::runtime_error("Unknown type for 'is' operator: " + name);
|
throw std::runtime_error("Unknown type for 'is' operator: " + name);
|
||||||
};
|
};
|
||||||
auto value = eval();
|
auto value = eval();
|
||||||
|
@ -1520,6 +1540,10 @@ public:
|
||||||
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
|
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
|
||||||
auto suffix = vargs.args[0].get<std::string>();
|
auto suffix = vargs.args[0].get<std::string>();
|
||||||
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
|
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
|
||||||
|
} else if (method->get_name() == "startswith") {
|
||||||
|
vargs.expectArgs("startswith method", {1, 1}, {0, 0});
|
||||||
|
auto prefix = vargs.args[0].get<std::string>();
|
||||||
|
return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
|
||||||
} else if (method->get_name() == "title") {
|
} else if (method->get_name() == "title") {
|
||||||
vargs.expectArgs("title method", {0, 0}, {0, 0});
|
vargs.expectArgs("title method", {0, 0}, {0, 0});
|
||||||
auto res = str;
|
auto res = str;
|
||||||
|
@ -2082,28 +2106,37 @@ private:
|
||||||
|
|
||||||
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
|
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
|
||||||
if (!consumeToken("[").empty()) {
|
if (!consumeToken("[").empty()) {
|
||||||
std::shared_ptr<Expression> index;
|
std::shared_ptr<Expression> index;
|
||||||
|
auto slice_loc = get_location();
|
||||||
|
std::shared_ptr<Expression> start, end, step;
|
||||||
|
bool has_first_colon = false, has_second_colon = false;
|
||||||
|
|
||||||
|
if (!peekSymbols({ ":" })) {
|
||||||
|
start = parseExpression();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!consumeToken(":").empty()) {
|
||||||
|
has_first_colon = true;
|
||||||
|
if (!peekSymbols({ ":", "]" })) {
|
||||||
|
end = parseExpression();
|
||||||
|
}
|
||||||
if (!consumeToken(":").empty()) {
|
if (!consumeToken(":").empty()) {
|
||||||
auto slice_end = parseExpression();
|
has_second_colon = true;
|
||||||
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
|
if (!peekSymbols({ "]" })) {
|
||||||
} else {
|
step = parseExpression();
|
||||||
auto slice_start = parseExpression();
|
|
||||||
if (!consumeToken(":").empty()) {
|
|
||||||
consumeSpaces();
|
|
||||||
if (peekSymbols({ "]" })) {
|
|
||||||
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
|
|
||||||
} else {
|
|
||||||
auto slice_end = parseExpression();
|
|
||||||
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
index = std::move(slice_start);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!index) throw std::runtime_error("Empty index in subscript");
|
}
|
||||||
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
|
|
||||||
|
|
||||||
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
|
if ((has_first_colon || has_second_colon) && (start || end || step)) {
|
||||||
|
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
|
||||||
|
} else {
|
||||||
|
index = std::move(start);
|
||||||
|
}
|
||||||
|
if (!index) throw std::runtime_error("Empty index in subscript");
|
||||||
|
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
|
||||||
|
|
||||||
|
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
|
||||||
} else if (!consumeToken(".").empty()) {
|
} else if (!consumeToken(".").empty()) {
|
||||||
auto identifier = parseIdentifier();
|
auto identifier = parseIdentifier();
|
||||||
if (!identifier) throw std::runtime_error("Expected identifier in subscript");
|
if (!identifier) throw std::runtime_error("Expected identifier in subscript");
|
||||||
|
|
204
common/regex-partial.cpp
Normal file
204
common/regex-partial.cpp
Normal file
|
@ -0,0 +1,204 @@
|
||||||
|
#include "regex-partial.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include <functional>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
common_regex::common_regex(const std::string & pattern) :
|
||||||
|
pattern(pattern),
|
||||||
|
rx(pattern),
|
||||||
|
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
|
||||||
|
|
||||||
|
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
|
||||||
|
std::smatch match;
|
||||||
|
if (pos > input.size()) {
|
||||||
|
throw std::runtime_error("Position out of bounds");
|
||||||
|
}
|
||||||
|
auto start = input.begin() + pos;
|
||||||
|
auto found = as_match
|
||||||
|
? std::regex_match(start, input.end(), match, rx)
|
||||||
|
: std::regex_search(start, input.end(), match, rx);
|
||||||
|
if (found) {
|
||||||
|
common_regex_match res;
|
||||||
|
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
|
||||||
|
for (size_t i = 0; i < match.size(); ++i) {
|
||||||
|
auto begin = pos + match.position(i);
|
||||||
|
res.groups.emplace_back(begin, begin + match.length(i));
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
std::match_results<std::string::const_reverse_iterator> srmatch;
|
||||||
|
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
|
||||||
|
auto group = srmatch[1].str();
|
||||||
|
if (group.length() != 0) {
|
||||||
|
auto it = srmatch[1].second.base();
|
||||||
|
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
|
||||||
|
if ((!as_match) || it == input.begin()) {
|
||||||
|
common_regex_match res;
|
||||||
|
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
|
||||||
|
const size_t begin = std::distance(input.begin(), it);
|
||||||
|
const size_t end = input.size();
|
||||||
|
if (begin == std::string::npos || end == std::string::npos || begin > end) {
|
||||||
|
throw std::runtime_error("Invalid range");
|
||||||
|
}
|
||||||
|
res.groups.push_back({begin, end});
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
|
||||||
|
|
||||||
|
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
|
||||||
|
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
||||||
|
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
||||||
|
|
||||||
|
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
|
||||||
|
- /a|b/ -> (a|b).*
|
||||||
|
- /a*?/ -> error, could match ""
|
||||||
|
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
|
||||||
|
- /.*?ab/ -> ((?:b)?a).* (merge .*)
|
||||||
|
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
|
||||||
|
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
|
||||||
|
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
|
||||||
|
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
|
||||||
|
|
||||||
|
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
|
||||||
|
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
|
||||||
|
*/
|
||||||
|
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
||||||
|
auto it = pattern.begin();
|
||||||
|
const auto end = pattern.end();
|
||||||
|
|
||||||
|
std::function<std::string()> process = [&]() {
|
||||||
|
std::vector<std::vector<std::string>> alternatives(1);
|
||||||
|
std::vector<std::string> * sequence = &alternatives.back();
|
||||||
|
|
||||||
|
while (it != end) {
|
||||||
|
if (*it == '[') {
|
||||||
|
auto start = it;
|
||||||
|
++it;
|
||||||
|
while (it != end) {
|
||||||
|
if ((*it == '\\') && (++it != end)) {
|
||||||
|
++it;
|
||||||
|
} else if ((it != end) && (*it == ']')) {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (it == end) {
|
||||||
|
throw std::runtime_error("Unmatched '[' in pattern");
|
||||||
|
}
|
||||||
|
++it;
|
||||||
|
sequence->push_back(std::string(start, it));
|
||||||
|
} else if (*it == '*' || *it == '?' || *it == '+') {
|
||||||
|
if (sequence->empty()) {
|
||||||
|
throw std::runtime_error("Quantifier without preceding element");
|
||||||
|
}
|
||||||
|
sequence->back() += *it;
|
||||||
|
auto is_star = *it == '*';
|
||||||
|
++it;
|
||||||
|
if (is_star) {
|
||||||
|
if (*it == '?') {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (*it == '{') {
|
||||||
|
if (sequence->empty()) {
|
||||||
|
throw std::runtime_error("Repetition without preceding element");
|
||||||
|
}
|
||||||
|
++it;
|
||||||
|
auto start = it;
|
||||||
|
while (it != end && *it != '}') {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
if (it == end) {
|
||||||
|
throw std::runtime_error("Unmatched '{' in pattern");
|
||||||
|
}
|
||||||
|
auto parts = string_split(std::string(start, it), ",");
|
||||||
|
++it;
|
||||||
|
if (parts.size() > 2) {
|
||||||
|
throw std::runtime_error("Invalid repetition range in pattern");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
|
||||||
|
if (s.empty()) {
|
||||||
|
return def;
|
||||||
|
}
|
||||||
|
return std::stoi(s);
|
||||||
|
};
|
||||||
|
auto min = parseOptInt(parts[0], 0);
|
||||||
|
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
|
||||||
|
if (min && max && *max < *min) {
|
||||||
|
throw std::runtime_error("Invalid repetition range in pattern");
|
||||||
|
}
|
||||||
|
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
|
||||||
|
auto part = sequence->back();
|
||||||
|
sequence->pop_back();
|
||||||
|
for (int i = 0; i < *min; i++) {
|
||||||
|
sequence->push_back(part);
|
||||||
|
}
|
||||||
|
if (max) {
|
||||||
|
for (int i = *min; i < *max; i++) {
|
||||||
|
sequence->push_back(part + "?");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sequence->push_back(part + "*");
|
||||||
|
}
|
||||||
|
} else if (*it == '(') {
|
||||||
|
++it;
|
||||||
|
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
|
||||||
|
it += 2;
|
||||||
|
}
|
||||||
|
auto sub = process();
|
||||||
|
if (*it != ')') {
|
||||||
|
throw std::runtime_error("Unmatched '(' in pattern");
|
||||||
|
}
|
||||||
|
++it;
|
||||||
|
auto & part = sequence->emplace_back("(?:");
|
||||||
|
part += sub;
|
||||||
|
part += ")";
|
||||||
|
} else if (*it == ')') {
|
||||||
|
break;
|
||||||
|
} else if (*it == '|') {
|
||||||
|
++it;
|
||||||
|
alternatives.emplace_back();
|
||||||
|
sequence = &alternatives.back();
|
||||||
|
} else if (*it == '\\' && (++it != end)) {
|
||||||
|
auto str = std::string("\\") + *it;
|
||||||
|
sequence->push_back(str);
|
||||||
|
++it;
|
||||||
|
} else if (it != end) {
|
||||||
|
sequence->push_back(std::string(1, *it));
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
|
||||||
|
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
|
||||||
|
// We'll do the outermost capturing group and final .* in the enclosing function.
|
||||||
|
std::vector<std::string> res_alts;
|
||||||
|
for (const auto & parts : alternatives) {
|
||||||
|
auto & res = res_alts.emplace_back();
|
||||||
|
for (size_t i = 0; i < parts.size() - 1; i++) {
|
||||||
|
res += "(?:";
|
||||||
|
}
|
||||||
|
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
|
||||||
|
res += *it;
|
||||||
|
if (it != parts.rend() - 1) {
|
||||||
|
res += ")?";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string_join(res_alts, "|");
|
||||||
|
};
|
||||||
|
auto res = process();
|
||||||
|
if (it != end) {
|
||||||
|
throw std::runtime_error("Unmatched '(' in pattern");
|
||||||
|
}
|
||||||
|
|
||||||
|
return "(" + res + ")[\\s\\S]*";
|
||||||
|
}
|
56
common/regex-partial.h
Normal file
56
common/regex-partial.h
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <regex>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
enum common_regex_match_type {
|
||||||
|
COMMON_REGEX_MATCH_TYPE_NONE,
|
||||||
|
COMMON_REGEX_MATCH_TYPE_PARTIAL,
|
||||||
|
COMMON_REGEX_MATCH_TYPE_FULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_string_range {
|
||||||
|
size_t begin;
|
||||||
|
size_t end;
|
||||||
|
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
||||||
|
if (begin > end) {
|
||||||
|
throw std::runtime_error("Invalid range");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// prevent default ctor
|
||||||
|
common_string_range() = delete;
|
||||||
|
bool empty() const {
|
||||||
|
return begin == end;
|
||||||
|
}
|
||||||
|
bool operator==(const common_string_range & other) const {
|
||||||
|
return begin == other.begin && end == other.end;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_regex_match {
|
||||||
|
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
|
||||||
|
std::vector<common_string_range> groups;
|
||||||
|
|
||||||
|
bool operator==(const common_regex_match & other) const {
|
||||||
|
return type == other.type && groups == other.groups;
|
||||||
|
}
|
||||||
|
bool operator!=(const common_regex_match & other) const {
|
||||||
|
return !(*this == other);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_regex {
|
||||||
|
std::string pattern;
|
||||||
|
std::regex rx;
|
||||||
|
std::regex rx_reversed_partial;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit common_regex(const std::string & pattern);
|
||||||
|
|
||||||
|
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
|
||||||
|
|
||||||
|
const std::string & str() const { return pattern; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// For testing only (pretty print of failures).
|
||||||
|
std::string regex_to_reversed_partial_regex(const std::string & pattern);
|
|
@ -2069,6 +2069,9 @@ class Llama4Model(LlamaModel):
|
||||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
|
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||||
|
if name.startswith("language_model."):
|
||||||
|
name = name.replace("language_model.", "")
|
||||||
|
|
||||||
# split the gate_up into gate and up
|
# split the gate_up into gate and up
|
||||||
if "gate_up_proj" in name:
|
if "gate_up_proj" in name:
|
||||||
name_up = name.replace("gate_up_proj", "up_proj.weight")
|
name_up = name.replace("gate_up_proj", "up_proj.weight")
|
||||||
|
|
|
@ -31,7 +31,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
|
||||||
|
|
||||||
## Pre-quantized models
|
## Pre-quantized models
|
||||||
|
|
||||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default.
|
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/ggml-org
|
||||||
|
|
||||||
Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server`
|
Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server`
|
||||||
|
|
||||||
|
|
|
@ -8520,7 +8520,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||||
|
|
||||||
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||||
assert(n % QK_K == 0);
|
assert(n % QK_K == 0);
|
||||||
|
#ifdef __ARM_FEATURE_MATMUL_INT8
|
||||||
|
assert((nrc == 2) || (nrc == 1));
|
||||||
|
#else
|
||||||
assert(nrc == 1);
|
assert(nrc == 1);
|
||||||
|
#endif
|
||||||
UNUSED(nrc);
|
UNUSED(nrc);
|
||||||
UNUSED(bx);
|
UNUSED(bx);
|
||||||
UNUSED(by);
|
UNUSED(by);
|
||||||
|
@ -8531,6 +8535,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||||
|
|
||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
|
if (nrc == 2) {
|
||||||
|
const block_q6_K * GGML_RESTRICT x0 = x;
|
||||||
|
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
|
||||||
|
const block_q8_K * GGML_RESTRICT y0 = y;
|
||||||
|
const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
|
||||||
|
|
||||||
|
float32x4_t vfsum = vdupq_n_f32(0.0f);
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
|
||||||
|
const uint8_t * GGML_RESTRICT ql0 = x0->ql;
|
||||||
|
const uint8_t * GGML_RESTRICT ql1 = x1->ql;
|
||||||
|
const uint8_t * GGML_RESTRICT qh0 = x0->qh;
|
||||||
|
const uint8_t * GGML_RESTRICT qh1 = x1->qh;
|
||||||
|
const int8_t * GGML_RESTRICT qy0 = y0->qs;
|
||||||
|
const int8_t * GGML_RESTRICT qy1 = y1->qs;
|
||||||
|
|
||||||
|
const uint8x16_t mone = vdupq_n_u8(0x30);
|
||||||
|
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||||
|
|
||||||
|
int32x4_t visum = vdupq_n_s32(0);
|
||||||
|
|
||||||
|
// process 8 blocks per iteration, totally 16 blocks
|
||||||
|
for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
|
||||||
|
int8x16_t vx0[8], vx1[8];
|
||||||
|
|
||||||
|
// de-quantize vx0[8]
|
||||||
|
{
|
||||||
|
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
|
||||||
|
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
|
||||||
|
|
||||||
|
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
||||||
|
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
||||||
|
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
||||||
|
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
||||||
|
|
||||||
|
vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
||||||
|
vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
||||||
|
vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
||||||
|
vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
||||||
|
|
||||||
|
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
||||||
|
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
||||||
|
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
||||||
|
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
||||||
|
|
||||||
|
vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
||||||
|
vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
||||||
|
vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
||||||
|
vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
||||||
|
}
|
||||||
|
|
||||||
|
// de-quantize vx1[8]
|
||||||
|
{
|
||||||
|
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
|
||||||
|
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
|
||||||
|
|
||||||
|
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
||||||
|
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
||||||
|
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
||||||
|
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
||||||
|
|
||||||
|
vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
||||||
|
vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
||||||
|
vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
||||||
|
vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
||||||
|
|
||||||
|
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
||||||
|
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
||||||
|
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
||||||
|
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
||||||
|
|
||||||
|
vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
||||||
|
vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
||||||
|
vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
||||||
|
vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
||||||
|
}
|
||||||
|
|
||||||
|
// process 16 elements (one block with same scale) per iteration
|
||||||
|
// - vx = concat(ql, qh) - 32
|
||||||
|
// - r1,r2,r3,r4 = smmla(vx, vy)
|
||||||
|
for (int k = 0; k < 8; ++k) {
|
||||||
|
const int blk = j * 8 + k;
|
||||||
|
|
||||||
|
const int8x16_t vy0 = vld1q_s8(qy0);
|
||||||
|
const int8x16_t vy1 = vld1q_s8(qy1);
|
||||||
|
qy0 += 16;
|
||||||
|
qy1 += 16;
|
||||||
|
|
||||||
|
const int32x4_t block_scale = {
|
||||||
|
x0->scales[blk],
|
||||||
|
x0->scales[blk],
|
||||||
|
x1->scales[blk],
|
||||||
|
x1->scales[blk],
|
||||||
|
};
|
||||||
|
|
||||||
|
// calculate four results at once with outer product
|
||||||
|
const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
||||||
|
const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
||||||
|
const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
||||||
|
const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
||||||
|
int32x4_t vr = vdupq_n_s32(0);
|
||||||
|
vr = vmmlaq_s32(vr, vx_l, vy_l);
|
||||||
|
vr = vmmlaq_s32(vr, vx_h, vy_h);
|
||||||
|
|
||||||
|
// apply block scale, will NOT overflow
|
||||||
|
// block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
|
||||||
|
visum = vmlaq_s32(visum, vr, block_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// adjust bias, apply superblock scale
|
||||||
|
{
|
||||||
|
int32_t bias[4];
|
||||||
|
#ifdef __ARM_FEATURE_SVE
|
||||||
|
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
||||||
|
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
|
||||||
|
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
|
||||||
|
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
|
||||||
|
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
|
||||||
|
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
|
||||||
|
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
|
||||||
|
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
|
||||||
|
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
|
||||||
|
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
|
||||||
|
const svint64_t zero = svdup_n_s64(0);
|
||||||
|
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
|
||||||
|
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
|
||||||
|
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
|
||||||
|
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
|
||||||
|
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
|
||||||
|
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
|
||||||
|
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
|
||||||
|
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
|
||||||
|
#else
|
||||||
|
// NEON doesn't support int16 dot product, fallback to separated mul and add
|
||||||
|
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
|
||||||
|
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
|
||||||
|
|
||||||
|
int8x16_t scales_s8 = vld1q_s8(x0->scales);
|
||||||
|
const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
||||||
|
scales_s8 = vld1q_s8(x1->scales);
|
||||||
|
const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
||||||
|
|
||||||
|
int32x4_t prod;
|
||||||
|
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
|
||||||
|
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
|
||||||
|
bias[0] = vaddvq_s32(prod);
|
||||||
|
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
|
||||||
|
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
|
||||||
|
bias[1] = vaddvq_s32(prod);
|
||||||
|
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
|
||||||
|
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||||
|
bias[2] = vaddvq_s32(prod);
|
||||||
|
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
|
||||||
|
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||||
|
bias[3] = vaddvq_s32(prod);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
|
||||||
|
|
||||||
|
const float32x4_t superblock_scale = {
|
||||||
|
GGML_FP16_TO_FP32(x0->d) * y0->d,
|
||||||
|
GGML_FP16_TO_FP32(x0->d) * y1->d,
|
||||||
|
GGML_FP16_TO_FP32(x1->d) * y0->d,
|
||||||
|
GGML_FP16_TO_FP32(x1->d) * y1->d,
|
||||||
|
};
|
||||||
|
|
||||||
|
visum = vsubq_s32(visum, vibias);
|
||||||
|
vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// vfsum = ABCD -> ACBD
|
||||||
|
// AC -> s, BD -> (s+bs)
|
||||||
|
vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
|
||||||
|
vst1_f32(s, vget_low_f32 (vfsum));
|
||||||
|
vst1_f32(s + bs, vget_high_f32(vfsum));
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef __ARM_FEATURE_SVE
|
#ifdef __ARM_FEATURE_SVE
|
||||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
|
|
|
@ -286,7 +286,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
||||||
.from_float = quantize_row_q6_K,
|
.from_float = quantize_row_q6_K,
|
||||||
.vec_dot = ggml_vec_dot_q6_K_q8_K,
|
.vec_dot = ggml_vec_dot_q6_K_q8_K,
|
||||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||||
|
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||||
|
.nrows = 2,
|
||||||
|
#else
|
||||||
.nrows = 1,
|
.nrows = 1,
|
||||||
|
#endif
|
||||||
},
|
},
|
||||||
[GGML_TYPE_IQ2_XXS] = {
|
[GGML_TYPE_IQ2_XXS] = {
|
||||||
.from_float = NULL,
|
.from_float = NULL,
|
||||||
|
|
|
@ -678,10 +678,14 @@ void launch_fattn(
|
||||||
) {
|
) {
|
||||||
constexpr int ncols = ncols1 * ncols2;
|
constexpr int ncols = ncols1 * ncols2;
|
||||||
|
|
||||||
|
const bool is_mla = DV == 512; // TODO better parameterization
|
||||||
|
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
const ggml_tensor * K = dst->src[1];
|
const ggml_tensor * K = dst->src[1];
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
|
GGML_ASSERT(V || is_mla);
|
||||||
|
|
||||||
const ggml_tensor * mask = dst->src[3];
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
|
||||||
ggml_tensor * KQV = dst;
|
ggml_tensor * KQV = dst;
|
||||||
|
@ -689,6 +693,10 @@ void launch_fattn(
|
||||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
|
||||||
|
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
|
||||||
|
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||||
|
|
||||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||||
|
@ -713,10 +721,10 @@ void launch_fattn(
|
||||||
size_t nb12 = K->nb[2];
|
size_t nb12 = K->nb[2];
|
||||||
size_t nb13 = K->nb[3];
|
size_t nb13 = K->nb[3];
|
||||||
|
|
||||||
const char * V_data = (const char *) V->data;
|
const char * V_data = V ? (const char *) V->data : nullptr;
|
||||||
size_t nb21 = V->nb[1];
|
size_t nb21 = V ? V->nb[1] : nb11;
|
||||||
size_t nb22 = V->nb[2];
|
size_t nb22 = V ? V->nb[2] : nb12;
|
||||||
size_t nb23 = V->nb[3];
|
size_t nb23 = V ? V->nb[3] : nb13;
|
||||||
|
|
||||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||||
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
||||||
|
@ -733,7 +741,7 @@ void launch_fattn(
|
||||||
nb13 = nb13*bs*sizeof(half)/ts;
|
nb13 = nb13*bs*sizeof(half)/ts;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
||||||
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
||||||
V_f16.alloc(ggml_nelements(V));
|
V_f16.alloc(ggml_nelements(V));
|
||||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||||
|
|
|
@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
|
||||||
static constexpr int nwarps_max = 4;
|
static constexpr int nwarps_max = 4;
|
||||||
static constexpr bool Q_in_reg = true;
|
static constexpr bool Q_in_reg = true;
|
||||||
static constexpr int nstages_target = 2;
|
static constexpr int nstages_target = 2;
|
||||||
static constexpr int nbatch_K2 = 32;
|
|
||||||
static constexpr int nbatch_V2 = 32;
|
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
static constexpr int nbatch_combine = 32;
|
return 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||||
|
return 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||||
|
return 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||||
|
return 32;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
|
||||||
static constexpr int nwarps_max = 4;
|
static constexpr int nwarps_max = 4;
|
||||||
static constexpr bool Q_in_reg = true;
|
static constexpr bool Q_in_reg = true;
|
||||||
static constexpr int nstages_target = 2;
|
static constexpr int nstages_target = 2;
|
||||||
static constexpr int nbatch_K2 = 40;
|
|
||||||
static constexpr int nbatch_V2 = 40;
|
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
static constexpr int nbatch_combine = 40;
|
return 40;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||||
|
return 40;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 40;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||||
|
return 40;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 40;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||||
|
return 40;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
|
||||||
static constexpr int nwarps_max = 4;
|
static constexpr int nwarps_max = 4;
|
||||||
static constexpr bool Q_in_reg = true;
|
static constexpr bool Q_in_reg = true;
|
||||||
static constexpr int nstages_target = 2;
|
static constexpr int nstages_target = 2;
|
||||||
static constexpr int nbatch_K2 = 48;
|
|
||||||
static constexpr int nbatch_V2 = 48;
|
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
static constexpr int nbatch_combine = 48;
|
return 48;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||||
|
return 48;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 48;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||||
|
return 48;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 48;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||||
|
return 48;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
|
||||||
static constexpr int nwarps_max = 4;
|
static constexpr int nwarps_max = 4;
|
||||||
static constexpr bool Q_in_reg = true;
|
static constexpr bool Q_in_reg = true;
|
||||||
static constexpr int nstages_target = 2;
|
static constexpr int nstages_target = 2;
|
||||||
static constexpr int nbatch_K2 = 56;
|
|
||||||
static constexpr int nbatch_V2 = 56;
|
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
static constexpr int nbatch_combine = 56;
|
return 56;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||||
|
return 56;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 56;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||||
|
return 56;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 56;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||||
|
return 56;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
|
||||||
static constexpr int nwarps_max = 4;
|
static constexpr int nwarps_max = 4;
|
||||||
static constexpr bool Q_in_reg = true;
|
static constexpr bool Q_in_reg = true;
|
||||||
static constexpr int nstages_target = 2;
|
static constexpr int nstages_target = 2;
|
||||||
static constexpr int nbatch_K2 = 64;
|
|
||||||
static constexpr int nbatch_V2 = 64;
|
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
static constexpr int nbatch_combine = 64;
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
|
||||||
static constexpr int nwarps_max = 4;
|
static constexpr int nwarps_max = 4;
|
||||||
static constexpr bool Q_in_reg = true;
|
static constexpr bool Q_in_reg = true;
|
||||||
static constexpr int nstages_target = 2;
|
static constexpr int nstages_target = 2;
|
||||||
static constexpr int nbatch_K2 = 128;
|
|
||||||
static constexpr int nbatch_V2 = 128;
|
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
static constexpr int nbatch_combine = 128;
|
return 128;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||||
|
return 128;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 128;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||||
|
return 128;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_combine_host(const int cc, const int ncols) {
|
||||||
|
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||||
|
return ncols <= 16 ? 128 : 64;
|
||||||
|
}
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
||||||
|
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
|
return ncols <= 16 ? 128 : 64;
|
||||||
|
#else
|
||||||
|
GGML_UNUSED(ncols);
|
||||||
|
return 128;
|
||||||
|
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
|
||||||
static constexpr int nwarps_max = 8;
|
static constexpr int nwarps_max = 8;
|
||||||
static constexpr bool Q_in_reg = false;
|
static constexpr bool Q_in_reg = false;
|
||||||
static constexpr int nstages_target = 1;
|
static constexpr int nstages_target = 1;
|
||||||
static constexpr int nbatch_K2 = 160;
|
|
||||||
static constexpr int nbatch_V2 = 128;
|
static int get_nbatch_K2_host(const int cc, const int ncols) {
|
||||||
static constexpr int nbatch_combine = 128;
|
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||||
|
return ncols <= 16 ? 96 : 160;
|
||||||
|
}
|
||||||
|
return ncols <= 16 ? 288 : 160;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_K2_device(int ncols) {
|
||||||
|
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
|
return ncols <= 16 ? 96 : 160;
|
||||||
|
#else
|
||||||
|
return ncols <= 16 ? 288 : 160;
|
||||||
|
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_V2_host(const int cc, const int ncols) {
|
||||||
|
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||||
|
return ncols <= 16 ? 64 : 128;
|
||||||
|
}
|
||||||
|
return ncols <= 16 ? 256 : 128;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_V2_device(int ncols) {
|
||||||
|
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
|
return ncols <= 16 ? 64 : 128;
|
||||||
|
#else
|
||||||
|
return ncols <= 16 ? 256 : 128;
|
||||||
|
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 128;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||||
|
return 128;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// ------------------------------------------------------------------------------------------------------------------
|
// ------------------------------------------------------------------------------------------------------------------
|
||||||
|
@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||||
|
|
||||||
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
||||||
|
|
||||||
auto load = [&] __device__ (const int n) {
|
auto load = [&] __device__ (auto n) {
|
||||||
const int stride_k = WARP_SIZE >> n;
|
const int stride_k = WARP_SIZE >> n;
|
||||||
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||||
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||||
|
@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
|
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
|
||||||
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
const float2 * const __restrict__ Q_f2,
|
const float2 * const __restrict__ Q_f2,
|
||||||
const half2 * const __restrict__ K_h2,
|
const half2 * const __restrict__ K_h2,
|
||||||
|
@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||||
|
constexpr int ncols = ncols1 * ncols2;
|
||||||
|
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||||
|
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||||
|
|
||||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||||
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||||
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
|
||||||
|
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||||
|
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||||
|
|
||||||
const int k_VKQ_0 = kb0 * c::nbatch_fa;
|
const int k_VKQ_0 = kb0 * c::nbatch_fa;
|
||||||
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
|
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
|
||||||
|
@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
||||||
|
|
||||||
if constexpr (nstages > 1) {
|
if constexpr (nstages > 1) {
|
||||||
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
||||||
|
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||||
constexpr bool use_cp_async = true;
|
constexpr bool use_cp_async = true;
|
||||||
cp_async_wait_all();
|
cp_async_wait_all();
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||||
(V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
|
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool use_cp_async = nstages == 1;
|
constexpr bool use_cp_async = nstages == 1;
|
||||||
if (ncols2 > 1 || mask_h2) {
|
if (ncols2 > 1 || mask_h2) {
|
||||||
|
@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
|
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
|
||||||
const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
|
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
||||||
const int k0_diff = k0_stop - k0_start;
|
const int k0_diff = k0_stop - k0_start;
|
||||||
|
|
||||||
if (nstages <= 1) {
|
if (nstages <= 1) {
|
||||||
|
@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
||||||
}
|
}
|
||||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||||
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
|
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
|
|
||||||
const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
|
|
||||||
const int i0_diff = i0_stop - i0_start;
|
|
||||||
|
|
||||||
if (nstages <= 1) {
|
// For MLA K and V have the same data.
|
||||||
|
// Therefore, iterate over V in reverse and re-use the data if possible.
|
||||||
|
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
||||||
|
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
||||||
|
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
||||||
|
const int i0_diff = i0_stop - i0_start;
|
||||||
|
|
||||||
|
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
||||||
constexpr bool use_cp_async = nstages == 1;
|
constexpr bool use_cp_async = nstages == 1;
|
||||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||||
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
||||||
|
@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
||||||
|
|
||||||
// Calculate VKQ tile:
|
// Calculate VKQ tile:
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
||||||
|
|
||||||
tile_A A;
|
tile_A A;
|
||||||
load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||||
if (ntiles == 1) {
|
if (ntiles == 1) {
|
||||||
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
||||||
} else {
|
} else {
|
||||||
|
@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
#endif // NEW_MMA_AVAILABLE
|
#endif // NEW_MMA_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
||||||
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||||
const float2 * const __restrict__ Q_f2,
|
const float2 * const __restrict__ Q_f2,
|
||||||
const half2 * const __restrict__ K_h2,
|
const half2 * const __restrict__ K_h2,
|
||||||
|
@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||||
|
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||||
|
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||||
|
|
||||||
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
||||||
|
|
||||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||||
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||||
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
|
||||||
|
|
||||||
|
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||||
|
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||||
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
||||||
|
|
||||||
extern __shared__ half2 tile_Q[];
|
extern __shared__ half2 tile_Q[];
|
||||||
|
@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||||
|
|
||||||
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
||||||
if constexpr (nstages > 1) {
|
if constexpr (nstages > 1) {
|
||||||
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
||||||
constexpr bool use_cp_async = true;
|
constexpr bool use_cp_async = true;
|
||||||
if (ncols2 > 1 || mask_h2) {
|
if (ncols2 > 1 || mask_h2) {
|
||||||
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
||||||
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
||||||
}
|
}
|
||||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||||
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
|
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterate over ne11 == previous tokens:
|
// Iterate over ne11 == previous tokens:
|
||||||
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
||||||
constexpr bool last_iter = false;
|
constexpr bool last_iter = false;
|
||||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||||
}
|
}
|
||||||
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
||||||
constexpr bool last_iter = true;
|
constexpr bool last_iter = true;
|
||||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
||||||
}
|
}
|
||||||
|
@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||||
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||||
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
||||||
|
|
||||||
constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
|
constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
|
||||||
constexpr int tile_stride = nbatch_combine + 4;
|
constexpr int tile_stride = nbatch_combine + 4;
|
||||||
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
||||||
|
|
||||||
|
@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||||
#endif // NEW_MMA_AVAILABLE
|
#endif // NEW_MMA_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
|
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
|
||||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||||
static __global__ void flash_attn_ext_f16(
|
static __global__ void flash_attn_ext_f16(
|
||||||
const char * __restrict__ Q,
|
const char * __restrict__ Q,
|
||||||
|
@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
|
if (ncols1*ncols2 > 32) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
|
|
||||||
|
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
||||||
|
|
||||||
typedef fattn_mma_f16_config<DKQ, DV> c;
|
typedef fattn_mma_f16_config<DKQ, DV> c;
|
||||||
|
|
||||||
|
@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int stride_Q1 = nb01 / sizeof(float2);
|
const int stride_Q1 = nb01 / sizeof(float2);
|
||||||
const int stride_Q2 = nb02 / sizeof(float2);
|
const int stride_Q2 = nb02 / sizeof(float2);
|
||||||
const int stride_K = nb11 / sizeof(half2);
|
const int stride_K = nb11 / sizeof(half2);
|
||||||
const int stride_V = nb21 / sizeof(half2);
|
|
||||||
const int stride_mask = nb31 / sizeof(half2);
|
const int stride_mask = nb31 / sizeof(half2);
|
||||||
|
|
||||||
|
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
||||||
|
|
||||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||||
|
|
||||||
|
@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||||
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
||||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||||
|
|
||||||
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||||
|
|
||||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||||
|
|
||||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||||
|
@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
|
||||||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||||
if (kb0_start == 0) {
|
if (kb0_start == 0) {
|
||||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||||
}
|
}
|
||||||
|
@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||||
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
|
||||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||||
|
|
||||||
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||||
|
|
||||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||||
|
|
||||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||||
|
@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||||
constexpr bool needs_fixup = false;
|
constexpr bool needs_fixup = false;
|
||||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||||
#else
|
#else
|
||||||
|
@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||||
|
|
||||||
typedef fattn_mma_f16_config<DKQ, DV> c;
|
typedef fattn_mma_f16_config<DKQ, DV> c;
|
||||||
|
|
||||||
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
|
|
||||||
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
|
|
||||||
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
|
|
||||||
|
|
||||||
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
|
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
|
||||||
|
|
||||||
constexpr int ncols = ncols1 * ncols2;
|
constexpr int ncols = ncols1 * ncols2;
|
||||||
|
@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||||
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
||||||
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
||||||
|
|
||||||
|
constexpr bool mla = DKQ == 576;
|
||||||
|
|
||||||
|
const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
|
||||||
|
const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
|
||||||
|
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
|
||||||
|
|
||||||
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
||||||
static_assert(DV % tile_A::J == 0, "bad DV");
|
static_assert(DV % tile_A::J == 0, "bad DV");
|
||||||
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
||||||
|
|
||||||
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
|
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
||||||
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
|
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
||||||
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
||||||
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
||||||
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
||||||
|
|
||||||
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
||||||
|
|
||||||
|
@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||||
fattn_kernel_t fattn_kernel;
|
fattn_kernel_t fattn_kernel;
|
||||||
if (logit_softcap == 0.0f) {
|
if (logit_softcap == 0.0f) {
|
||||||
constexpr bool use_logit_softcap = false;
|
constexpr bool use_logit_softcap = false;
|
||||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
||||||
|
|
||||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||||
|
@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||||
} else {
|
} else {
|
||||||
constexpr bool use_logit_softcap = true;
|
constexpr bool use_logit_softcap = true;
|
||||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
||||||
|
|
||||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
|
|
||||||
template <int DKQ, int DV, int ncols2>
|
template <int DKQ, int DV, int ncols2>
|
||||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
|
||||||
if constexpr (ncols2 <= 8) {
|
if constexpr (ncols2 <= 8) {
|
||||||
|
@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] <= 32/ncols2) {
|
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
|
||||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -3227,7 +3227,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
#endif // FLASH_ATTN_AVAILABLE
|
#endif // FLASH_ATTN_AVAILABLE
|
||||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
||||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
||||||
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
|
if (!new_mma_available(cc)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
||||||
|
|
|
@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q(
|
||||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||||
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
|
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
|
||||||
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
|
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||||
|
@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q(
|
||||||
const int64_t s13 = src1->nb[2] / ts_src1;
|
const int64_t s13 = src1->nb[2] / ts_src1;
|
||||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
|
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
|
||||||
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||||
|
|
|
@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
|
||||||
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
|
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
|
||||||
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
|
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
|
||||||
|
|
||||||
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
|
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
|
||||||
|
|
||||||
if (i0 >= ne0) {
|
if (i0 >= ne0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t i1 = blockIdx.y;
|
const int64_t i1 = blockIdx.x;
|
||||||
const int64_t i2 = blockIdx.z % ne2;
|
const int64_t i2 = blockIdx.z % ne2;
|
||||||
const int64_t i3 = blockIdx.z / ne2;
|
const int64_t i3 = blockIdx.z / ne2;
|
||||||
|
|
||||||
|
@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
|
||||||
|
|
||||||
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
||||||
|
|
||||||
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
|
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
|
||||||
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
|
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
|
||||||
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
|
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
|
||||||
|
|
||||||
// Load 4 floats per thread and calculate max. abs. value between them:
|
// Load 4 floats per thread and calculate max. abs. value between them:
|
||||||
|
@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda(
|
||||||
GGML_ASSERT(ne00 % 4 == 0);
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
|
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
|
||||||
|
|
||||||
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
|
||||||
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
|
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||||
|
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
|
||||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
|
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
|
||||||
switch (mmq_get_q8_1_ds_layout(type_src0)) {
|
switch (mmq_get_q8_1_ds_layout(type_src0)) {
|
||||||
case MMQ_Q8_1_DS_LAYOUT_D4:
|
case MMQ_Q8_1_DS_LAYOUT_D4:
|
||||||
|
|
|
@ -56,6 +56,28 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
||||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <> struct block_q_t<GGML_TYPE_Q4_K> {
|
||||||
|
struct traits {
|
||||||
|
static constexpr uint32_t qk = QK_K;
|
||||||
|
static constexpr uint32_t qi = QI4_K;
|
||||||
|
static constexpr uint32_t qr = QR4_K;
|
||||||
|
static constexpr uint32_t vdr_mmvq = 2;
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||||
|
|
||||||
|
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||||
|
auto nblocks = (nrows * (ncols / traits::qk));
|
||||||
|
return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||||
|
|
||||||
|
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
|
||||||
|
|
||||||
|
constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace ggml_sycl_reordered
|
} // namespace ggml_sycl_reordered
|
||||||
|
|
||||||
#endif // GGML_SYCL_QUANTS_HPP
|
#endif // GGML_SYCL_QUANTS_HPP
|
||||||
|
|
|
@ -304,6 +304,9 @@ struct vk_device_struct {
|
||||||
bool coopmat_acc_f32_support {};
|
bool coopmat_acc_f32_support {};
|
||||||
bool coopmat_acc_f16_support {};
|
bool coopmat_acc_f16_support {};
|
||||||
bool coopmat_bf16_support {};
|
bool coopmat_bf16_support {};
|
||||||
|
bool coopmat_support_16x16x16_f16acc {};
|
||||||
|
bool coopmat_support_16x16x16_f32acc {};
|
||||||
|
bool coopmat1_fa_support {};
|
||||||
uint32_t coopmat_m;
|
uint32_t coopmat_m;
|
||||||
uint32_t coopmat_n;
|
uint32_t coopmat_n;
|
||||||
uint32_t coopmat_k;
|
uint32_t coopmat_k;
|
||||||
|
@ -426,6 +429,13 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
|
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||||
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||||
|
|
||||||
|
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||||
|
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||||
|
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||||
|
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||||
|
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||||
|
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||||
|
|
||||||
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
||||||
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
||||||
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
||||||
|
@ -1604,19 +1614,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum FaCodePath {
|
||||||
|
FA_SCALAR,
|
||||||
|
FA_COOPMAT1,
|
||||||
|
FA_COOPMAT2,
|
||||||
|
};
|
||||||
|
|
||||||
// number of rows/cols for flash attention shader
|
// number of rows/cols for flash attention shader
|
||||||
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
||||||
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
||||||
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
||||||
|
|
||||||
static uint32_t get_fa_num_small_rows(bool scalar) {
|
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
||||||
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
|
// 128 threads split into four subgroups, each subgroup does 1/4
|
||||||
|
// of the Bc dimension.
|
||||||
|
static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
|
||||||
|
static constexpr uint32_t scalar_flash_attention_Bc = 64;
|
||||||
|
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
|
||||||
|
|
||||||
|
static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
||||||
|
if (path == FA_COOPMAT2) {
|
||||||
|
return flash_attention_num_small_rows;
|
||||||
|
} else {
|
||||||
|
return scalar_flash_attention_num_small_rows;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||||
GGML_UNUSED(clamp);
|
GGML_UNUSED(clamp);
|
||||||
|
|
||||||
if (scalar) {
|
if (path == FA_SCALAR) {
|
||||||
if (small_rows) {
|
if (small_rows) {
|
||||||
return {scalar_flash_attention_num_small_rows, 64};
|
return {scalar_flash_attention_num_small_rows, 64};
|
||||||
} else {
|
} else {
|
||||||
|
@ -1624,9 +1651,17 @@ static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t cl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (path == FA_COOPMAT1) {
|
||||||
|
if (small_rows) {
|
||||||
|
return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
|
||||||
|
} else {
|
||||||
|
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// small rows, large cols
|
// small rows, large cols
|
||||||
if (small_rows) {
|
if (small_rows) {
|
||||||
return {get_fa_num_small_rows(scalar), 32};
|
return {get_fa_num_small_rows(FA_COOPMAT2), 32};
|
||||||
}
|
}
|
||||||
|
|
||||||
// small cols to reduce register count
|
// small cols to reduce register count
|
||||||
|
@ -1923,17 +1958,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
||||||
};
|
};
|
||||||
|
|
||||||
auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||||
return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
|
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
|
||||||
};
|
};
|
||||||
|
|
||||||
auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||||
// For large number of rows, 128 invocations seems to work best.
|
// For large number of rows, 128 invocations seems to work best.
|
||||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||||
// can't use 256 for D==80.
|
// can't use 256 for D==80.
|
||||||
// For scalar, use 128 (arbitrary)
|
// For scalar, use 128 (arbitrary)
|
||||||
uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
|
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
||||||
auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
|
? scalar_flash_attention_workgroup_size
|
||||||
|
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||||
|
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
|
||||||
|
|
||||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||||
|
@ -1945,36 +1982,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
|
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
|
||||||
};
|
};
|
||||||
|
|
||||||
#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
|
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
|
|
||||||
#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
|
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
|
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
|
||||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
|
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
|
||||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
|
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
|
||||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
|
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
|
||||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
|
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
|
||||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
|
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
|
||||||
|
|
||||||
CREATE_FA(GGML_TYPE_F16, f16, true, )
|
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
|
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
|
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||||
|
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||||
|
if (device->coopmat1_fa_support) {
|
||||||
|
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
||||||
|
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
||||||
|
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||||
if (device->coopmat2) {
|
if (device->coopmat2) {
|
||||||
CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
|
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
||||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
|
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
||||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
|
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
||||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
|
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
|
||||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
|
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
|
||||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
|
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
|
||||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
|
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#undef CREATE_FA2
|
#undef CREATE_FA2
|
||||||
|
@ -2057,17 +2101,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||||
|
|
||||||
// Create 2 variants, {f16,f32} accumulator
|
// Create 2 variants, {f16,f32} accumulator
|
||||||
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
|
@ -3033,6 +3077,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
|
|
||||||
#if defined(VK_KHR_cooperative_matrix)
|
#if defined(VK_KHR_cooperative_matrix)
|
||||||
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
||||||
|
|
||||||
|
// coopmat1 fa shader currently assumes 32 invocations per subgroup
|
||||||
|
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
|
||||||
|
device->subgroup_size_control && device->subgroup_min_size <= 32 &&
|
||||||
|
device->subgroup_max_size >= 32;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (coopmat2_support) {
|
if (coopmat2_support) {
|
||||||
|
@ -3167,6 +3216,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
// Only enable if shape is identical
|
// Only enable if shape is identical
|
||||||
device->coopmat_acc_f32_support = true;
|
device->coopmat_acc_f32_support = true;
|
||||||
}
|
}
|
||||||
|
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
||||||
|
device->coopmat_support_16x16x16_f32acc = true;
|
||||||
|
}
|
||||||
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
|
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
|
||||||
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
|
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
|
||||||
// coopmat sizes not set yet
|
// coopmat sizes not set yet
|
||||||
|
@ -3179,6 +3231,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
// Only enable if shape is identical
|
// Only enable if shape is identical
|
||||||
device->coopmat_acc_f16_support = true;
|
device->coopmat_acc_f16_support = true;
|
||||||
}
|
}
|
||||||
|
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
||||||
|
device->coopmat_support_16x16x16_f16acc = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
|
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
|
||||||
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
|
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
|
||||||
|
@ -5712,6 +5767,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
|
||||||
|
// Needs to be kept up to date on shader changes
|
||||||
|
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||||
|
const uint32_t Br = scalar_flash_attention_num_large_rows;
|
||||||
|
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||||
|
|
||||||
|
const uint32_t acctype = f32acc ? 4 : 2;
|
||||||
|
const uint32_t f16vec4 = 8;
|
||||||
|
|
||||||
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||||
|
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
||||||
|
|
||||||
|
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
|
||||||
|
|
||||||
|
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
||||||
|
const uint32_t sfsh = Bc * sfshstride * acctype;
|
||||||
|
|
||||||
|
const uint32_t kshstride = D / 4 + 2;
|
||||||
|
const uint32_t ksh = Bc * kshstride * f16vec4;
|
||||||
|
|
||||||
|
const uint32_t slope = Br * sizeof(float);
|
||||||
|
|
||||||
|
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
||||||
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||||
|
|
||||||
|
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
||||||
|
|
||||||
|
return supported;
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
|
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
|
||||||
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
|
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
|
||||||
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
|
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
|
||||||
|
@ -5762,7 +5847,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||||
assert(q->type == GGML_TYPE_F32);
|
assert(q->type == GGML_TYPE_F32);
|
||||||
assert(k->type == v->type);
|
assert(k->type == v->type);
|
||||||
|
|
||||||
bool scalar = !ctx->device->coopmat2;
|
FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
|
||||||
|
ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||||
|
|
||||||
|
if (path == FA_COOPMAT1) {
|
||||||
|
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
||||||
|
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
||||||
|
|
||||||
|
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
|
||||||
|
|
||||||
|
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
||||||
|
path = FA_SCALAR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t gqa_ratio = 1;
|
uint32_t gqa_ratio = 1;
|
||||||
uint32_t qk_ratio = neq2 / nek2;
|
uint32_t qk_ratio = neq2 / nek2;
|
||||||
|
@ -5770,9 +5867,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||||
uint32_t workgroups_y = (uint32_t)neq2;
|
uint32_t workgroups_y = (uint32_t)neq2;
|
||||||
uint32_t workgroups_z = (uint32_t)neq3;
|
uint32_t workgroups_z = (uint32_t)neq3;
|
||||||
|
|
||||||
// For scalar FA, we can use the "large" size to accommodate qga.
|
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
|
||||||
// For coopmat FA, we always use the small size (which is still pretty large for gqa).
|
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||||
const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
|
uint32_t max_gqa;
|
||||||
|
switch (path) {
|
||||||
|
case FA_SCALAR:
|
||||||
|
case FA_COOPMAT1:
|
||||||
|
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
||||||
|
max_gqa = scalar_flash_attention_num_large_rows;
|
||||||
|
break;
|
||||||
|
case FA_COOPMAT2:
|
||||||
|
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(0);
|
||||||
|
}
|
||||||
|
|
||||||
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
||||||
|
@ -5785,11 +5894,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||||
}
|
}
|
||||||
|
|
||||||
vk_pipeline *pipelines;
|
vk_pipeline *pipelines;
|
||||||
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
|
bool small_rows = N <= get_fa_num_small_rows(path);
|
||||||
bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
|
|
||||||
bool small_rows = N <= get_fa_num_small_rows(scalar);
|
|
||||||
|
|
||||||
if (scalar) {
|
if (small_rows && path == FA_COOPMAT1) {
|
||||||
|
path = FA_SCALAR;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||||
|
|
||||||
|
switch (path) {
|
||||||
|
case FA_SCALAR:
|
||||||
switch (D) {
|
switch (D) {
|
||||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
||||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
||||||
|
@ -5801,7 +5915,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||||
GGML_ASSERT(!"unsupported D value");
|
GGML_ASSERT(!"unsupported D value");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} else {
|
break;
|
||||||
|
case FA_COOPMAT1:
|
||||||
|
switch (D) {
|
||||||
|
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
|
||||||
|
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
|
||||||
|
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
|
||||||
|
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
|
||||||
|
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
|
||||||
|
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(!"unsupported D value");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case FA_COOPMAT2:
|
||||||
switch (D) {
|
switch (D) {
|
||||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
||||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
||||||
|
@ -5813,6 +5941,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||||
GGML_ASSERT(!"unsupported D value");
|
GGML_ASSERT(!"unsupported D value");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(0);
|
||||||
}
|
}
|
||||||
assert(pipelines);
|
assert(pipelines);
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
||||||
layout (constant_id = 1) const uint32_t Br = 1;
|
layout (constant_id = 1) const uint32_t Br = 1;
|
||||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||||
layout (constant_id = 3) const uint32_t D = 32;
|
layout (constant_id = 3) const uint32_t D = 32;
|
||||||
|
@ -19,7 +20,7 @@ layout (constant_id = 3) const uint32_t D = 32;
|
||||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||||
const uint32_t D_per_thread = D / D_split;
|
const uint32_t D_per_thread = D / D_split;
|
||||||
|
|
||||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split;
|
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
||||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
layout (push_constant) uniform parameter {
|
||||||
|
@ -134,8 +135,8 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
|
||||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||||
}
|
}
|
||||||
|
|
||||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
||||||
shared vec4 tmpshv4[gl_WorkGroupSize.x];
|
shared vec4 tmpshv4[WorkGroupSize];
|
||||||
|
|
||||||
shared float masksh[Bc][Br];
|
shared float masksh[Bc][Br];
|
||||||
shared vec4 Qf[Br][D / 4];
|
shared vec4 Qf[Br][D / 4];
|
||||||
|
|
506
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
Normal file
506
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
Normal file
|
@ -0,0 +1,506 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_EXT_shader_16bit_storage : require
|
||||||
|
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||||
|
|
||||||
|
#extension GL_KHR_shader_subgroup_basic : enable
|
||||||
|
#extension GL_KHR_memory_scope_semantics : enable
|
||||||
|
#extension GL_KHR_cooperative_matrix : enable
|
||||||
|
|
||||||
|
#include "types.comp"
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (constant_id = 1) const uint32_t Br = 1;
|
||||||
|
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||||
|
layout (constant_id = 3) const uint32_t D = 32;
|
||||||
|
|
||||||
|
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||||
|
|
||||||
|
const uint32_t D_per_thread = D / D_split;
|
||||||
|
const uint32_t row_split = 4;
|
||||||
|
const uint32_t rows_per_thread = Br / row_split;
|
||||||
|
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
|
||||||
|
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter {
|
||||||
|
uint32_t N;
|
||||||
|
uint32_t KV;
|
||||||
|
|
||||||
|
uint32_t ne1;
|
||||||
|
uint32_t ne2;
|
||||||
|
uint32_t ne3;
|
||||||
|
|
||||||
|
uint32_t neq2;
|
||||||
|
uint32_t neq3;
|
||||||
|
uint32_t nek2;
|
||||||
|
uint32_t nek3;
|
||||||
|
uint32_t nev2;
|
||||||
|
uint32_t nev3;
|
||||||
|
uint32_t nem1;
|
||||||
|
|
||||||
|
uint32_t nb01;
|
||||||
|
uint32_t nb02;
|
||||||
|
uint32_t nb03;
|
||||||
|
uint32_t nb11;
|
||||||
|
uint32_t nb12;
|
||||||
|
uint32_t nb13;
|
||||||
|
uint32_t nb21;
|
||||||
|
uint32_t nb22;
|
||||||
|
uint32_t nb23;
|
||||||
|
uint32_t nb31;
|
||||||
|
|
||||||
|
float scale;
|
||||||
|
float max_bias;
|
||||||
|
float logit_softcap;
|
||||||
|
|
||||||
|
uint32_t mask;
|
||||||
|
uint32_t n_head_log2;
|
||||||
|
float m0;
|
||||||
|
float m1;
|
||||||
|
|
||||||
|
uint32_t gqa_ratio;
|
||||||
|
uint32_t split_kv;
|
||||||
|
uint32_t k_num;
|
||||||
|
} p;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||||
|
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
|
||||||
|
layout (binding = 1) readonly buffer K {float16_t data_k[];};
|
||||||
|
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
|
||||||
|
layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||||
|
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||||
|
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||||
|
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||||
|
|
||||||
|
#if defined(A_TYPE_PACKED16)
|
||||||
|
#define BINDING_IDX_K 0
|
||||||
|
#define BINDING_IDX_V 1
|
||||||
|
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_Q4_0)
|
||||||
|
#define BLOCK_BYTE_SIZE 18
|
||||||
|
|
||||||
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||||
|
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||||
|
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||||
|
uint shift = (iqs & 0x10) >> 2;
|
||||||
|
vui_lo >>= shift;
|
||||||
|
vui_hi >>= shift;
|
||||||
|
|
||||||
|
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_Q8_0)
|
||||||
|
#define BLOCK_BYTE_SIZE 34
|
||||||
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||||
|
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||||
|
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||||
|
|
||||||
|
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||||
|
|
||||||
|
// Store the output when doing grouped query attention.
|
||||||
|
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||||
|
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||||
|
{
|
||||||
|
uint32_t offset = (iq2 + r) * D + c;
|
||||||
|
data_o[o_offset + offset] = D_TYPE(elem);
|
||||||
|
return elem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||||
|
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||||
|
{
|
||||||
|
if (r < N && c == 0) {
|
||||||
|
uint32_t offset = iq2 + r;
|
||||||
|
data_o[o_offset + offset] = D_TYPE(elem);
|
||||||
|
}
|
||||||
|
return elem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the slope matrix, indexed by Q's dimension 2.
|
||||||
|
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||||
|
{
|
||||||
|
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||||
|
|
||||||
|
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||||
|
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||||
|
|
||||||
|
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
|
||||||
|
const uint32_t MatBr = 16;
|
||||||
|
const uint32_t MatBc = 16;
|
||||||
|
|
||||||
|
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||||
|
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
||||||
|
|
||||||
|
const uint32_t qstride = D / 4 + 2; // in units of f16vec4
|
||||||
|
shared f16vec4 Qf[Br * qstride];
|
||||||
|
|
||||||
|
// Avoid padding for D==256 to make it fit in 48KB shmem.
|
||||||
|
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
||||||
|
shared ACC_TYPE sfsh[Bc * sfshstride];
|
||||||
|
|
||||||
|
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
|
||||||
|
shared f16vec4 ksh[Bc * kshstride];
|
||||||
|
|
||||||
|
shared float slope[Br];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||||
|
init_iq_shmem(gl_WorkGroupSize);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const uint32_t tid = gl_LocalInvocationIndex;
|
||||||
|
const uint32_t N = p.N;
|
||||||
|
const uint32_t KV = p.KV;
|
||||||
|
|
||||||
|
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
||||||
|
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
||||||
|
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||||
|
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
|
||||||
|
|
||||||
|
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||||
|
|
||||||
|
uint32_t i = gl_WorkGroupID.x;
|
||||||
|
uint32_t split_k_index = 0;
|
||||||
|
|
||||||
|
if (p.k_num > 1) {
|
||||||
|
i = 0;
|
||||||
|
split_k_index = gl_WorkGroupID.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||||
|
|
||||||
|
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
||||||
|
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||||
|
|
||||||
|
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||||
|
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||||
|
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||||
|
const uint32_t iq3 = gl_WorkGroupID.z;
|
||||||
|
|
||||||
|
// broadcast factors
|
||||||
|
const uint32_t rk2 = p.neq2/p.nek2;
|
||||||
|
const uint32_t rk3 = p.neq3/p.nek3;
|
||||||
|
|
||||||
|
const uint32_t rv2 = p.neq2/p.nev2;
|
||||||
|
const uint32_t rv3 = p.neq3/p.nev3;
|
||||||
|
|
||||||
|
// k indices
|
||||||
|
const uint32_t ik3 = iq3 / rk3;
|
||||||
|
const uint32_t ik2 = iq2 / rk2;
|
||||||
|
|
||||||
|
// v indices
|
||||||
|
const uint32_t iv3 = iq3 / rv3;
|
||||||
|
const uint32_t iv2 = iq2 / rv2;
|
||||||
|
|
||||||
|
// nb?1 are already divided by the type size and are in units of elements.
|
||||||
|
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||||
|
// should be nb02 (which is in bytes).
|
||||||
|
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||||
|
uint32_t k_stride = p.nb11;
|
||||||
|
uint32_t v_stride = p.nb21;
|
||||||
|
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||||
|
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||||
|
// that prevents the compiler from folding the "&" through the select
|
||||||
|
// and breaking the alignment detection.
|
||||||
|
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||||
|
|
||||||
|
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
||||||
|
uint32_t d = (idx + tid) % (D / 4);
|
||||||
|
uint32_t r = (idx + tid) / (D / 4);
|
||||||
|
if (r < Br && d < D / 4 &&
|
||||||
|
i * Br + r < N) {
|
||||||
|
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
|
||||||
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Of[r][d] = ACC_TYPEV4(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float Lf[rows_per_thread], Mf[rows_per_thread];
|
||||||
|
|
||||||
|
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
|
||||||
|
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Lf[r] = 0;
|
||||||
|
Mf[r] = NEG_FLT_MAX_OVER_2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ALiBi
|
||||||
|
if (p.max_bias > 0.0f) {
|
||||||
|
if (tid < Br) {
|
||||||
|
uint r = tid;
|
||||||
|
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
} else {
|
||||||
|
if (tid < Br) {
|
||||||
|
uint r = tid;
|
||||||
|
slope[r] = 1.0;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if BLOCK_SIZE > 1
|
||||||
|
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
||||||
|
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
||||||
|
#else
|
||||||
|
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||||
|
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
[[dont_unroll]]
|
||||||
|
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
|
||||||
|
uint32_t d = (idx + tid) % (D / 4);
|
||||||
|
uint32_t c = (idx + tid) / (D / 4);
|
||||||
|
if (c < Bc && d < D / 4) {
|
||||||
|
#if BLOCK_SIZE > 1
|
||||||
|
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||||
|
uint ib = coord / BLOCK_SIZE;
|
||||||
|
uint iqs = (coord % BLOCK_SIZE);
|
||||||
|
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||||
|
#else
|
||||||
|
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
ksh[c * kshstride + d] = K_Tf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
|
||||||
|
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||||
|
// This is written transposed in order to allow for N being 8 if implementations need it
|
||||||
|
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||||
|
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||||
|
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||||
|
|
||||||
|
for (uint32_t d = 0; d < D / 16; ++d) {
|
||||||
|
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||||
|
|
||||||
|
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||||
|
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||||
|
|
||||||
|
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint coord = gl_SubgroupID * MatBc * sfshstride;
|
||||||
|
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
if (p.logit_softcap != 0.0f) {
|
||||||
|
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||||
|
uint32_t c = (idx + tid) / Br;
|
||||||
|
uint32_t r = (idx + tid) % Br;
|
||||||
|
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||||
|
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (p.mask != 0) {
|
||||||
|
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||||
|
uint32_t c = (idx + tid) % Bc;
|
||||||
|
uint32_t r = (idx + tid) / Bc;
|
||||||
|
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||||
|
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
float eMf[rows_per_thread];
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
|
||||||
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||||
|
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
|
||||||
|
}
|
||||||
|
float Moldf = Mf[r];
|
||||||
|
|
||||||
|
// M = max(rowmax, Mold)
|
||||||
|
// P = e^(S - M)
|
||||||
|
// eM = e^(Mold - M)
|
||||||
|
Mf[r] = max(rowmaxf, Moldf);
|
||||||
|
eMf[r] = exp(Moldf - Mf[r]);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Lf[r] = eMf[r]*Lf[r];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||||
|
float Pf[rows_per_thread];
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
||||||
|
Lf[r] += Pf[r];
|
||||||
|
}
|
||||||
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
|
#if BLOCK_SIZE > 1
|
||||||
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||||
|
uint ib = coord / BLOCK_SIZE;
|
||||||
|
uint iqs = (coord % BLOCK_SIZE);
|
||||||
|
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||||
|
#else
|
||||||
|
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||||
|
#endif
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
// reduce across threads
|
||||||
|
|
||||||
|
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
FLOAT_TYPE M = Mf[r];
|
||||||
|
tmpsh[tid] = M;
|
||||||
|
// Compute max across the row
|
||||||
|
barrier();
|
||||||
|
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||||
|
M = max(M, tmpsh[tid ^ s]);
|
||||||
|
barrier();
|
||||||
|
tmpsh[tid] = M;
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Moldf[r] = Mf[r];
|
||||||
|
|
||||||
|
// M = max(rowmax, Mold)
|
||||||
|
// eM = e^(Mold - M)
|
||||||
|
Mf[r] = max(rowmaxf[r], Moldf[r]);
|
||||||
|
eMf[r] = exp(Moldf[r] - Mf[r]);
|
||||||
|
|
||||||
|
Lf[r] = eMf[r]*Lf[r];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
FLOAT_TYPE L = Lf[r];
|
||||||
|
tmpsh[tid] = L;
|
||||||
|
// Compute sum across the row
|
||||||
|
barrier();
|
||||||
|
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||||
|
L += tmpsh[tid ^ s];
|
||||||
|
barrier();
|
||||||
|
tmpsh[tid] = L;
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
|
|
||||||
|
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||||
|
tmpshv4[tid] = Of[r][d];
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||||
|
Of[r][d] += tmpshv4[tid ^ s];
|
||||||
|
barrier();
|
||||||
|
tmpshv4[tid] = Of[r][d];
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is split_k, then the split_k resolve shader does the final
|
||||||
|
// division by L. Store the intermediate O value and per-row m and L values.
|
||||||
|
if (p.k_num > 1) {
|
||||||
|
uint32_t o_offset = D * p.ne1 * split_k_index;
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
if (tile_row(r) < N) {
|
||||||
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||||
|
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
if (tile_row(r) < N) {
|
||||||
|
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||||
|
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float Lfrcp[rows_per_thread];
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Lfrcp[r] = 1.0 / Lf[r];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Of[r][d] *= float16_t(Lfrcp[r]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
||||||
|
|
||||||
|
if (p.gqa_ratio > 1) {
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
if (tile_row(r) < N) {
|
||||||
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||||
|
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
if (i * Br + tile_row(r) < N) {
|
||||||
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||||
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||||
|
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -227,7 +227,7 @@ static std::mutex compile_count_mutex;
|
||||||
static std::condition_variable compile_count_cond;
|
static std::condition_variable compile_count_cond;
|
||||||
|
|
||||||
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
||||||
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
||||||
std::string out_fname = join_paths(output_dir, name + ".spv");
|
std::string out_fname = join_paths(output_dir, name + ".spv");
|
||||||
std::string in_path = join_paths(input_dir, in_fname);
|
std::string in_path = join_paths(input_dir, in_fname);
|
||||||
|
|
||||||
|
@ -438,6 +438,7 @@ void process_shaders() {
|
||||||
// flash attention
|
// flash attention
|
||||||
for (const auto& f16acc : {false, true}) {
|
for (const auto& f16acc : {false, true}) {
|
||||||
std::string acctype = f16acc ? "float16_t" : "float";
|
std::string acctype = f16acc ? "float16_t" : "float";
|
||||||
|
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
|
||||||
|
|
||||||
for (const auto& tname : type_names) {
|
for (const auto& tname : type_names) {
|
||||||
if (tname == "f32") {
|
if (tname == "f32") {
|
||||||
|
@ -454,6 +455,16 @@ void process_shaders() {
|
||||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||||
|
if (tname == "f16") {
|
||||||
|
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||||
|
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||||
|
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||||
|
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||||
|
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||||
|
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
if (tname == "f16") {
|
if (tname == "f16") {
|
||||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||||
|
|
|
@ -309,10 +309,10 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} catch (std::length_error &) {
|
} catch (std::length_error &) {
|
||||||
fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
|
GGML_LOG_ERROR("%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
|
||||||
return false;
|
return false;
|
||||||
} catch (std::bad_alloc &) {
|
} catch (std::bad_alloc &) {
|
||||||
fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
|
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
kv.emplace_back(key, value);
|
kv.emplace_back(key, value);
|
||||||
|
@ -338,14 +338,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
ok = ok && gr.read(magic, 4);
|
ok = ok && gr.read(magic, 4);
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
fprintf(stderr, "%s: failed to read magic\n", __func__);
|
GGML_LOG_ERROR("%s: failed to read magic\n", __func__);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < magic.size(); i++) {
|
for (uint32_t i = 0; i < magic.size(); i++) {
|
||||||
if (magic[i] != GGUF_MAGIC[i]) {
|
if (magic[i] != GGUF_MAGIC[i]) {
|
||||||
fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
|
GGML_LOG_ERROR("%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -393,7 +393,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
if (ok && gr.read(n_tensors)) {
|
if (ok && gr.read(n_tensors)) {
|
||||||
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
|
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
|
||||||
if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
|
if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
|
||||||
fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
|
GGML_LOG_ERROR("%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
|
||||||
__func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
|
__func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
|
@ -404,7 +404,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
if (ok && gr.read(n_kv)) {
|
if (ok && gr.read(n_kv)) {
|
||||||
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
|
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
|
||||||
if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
|
if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
|
||||||
fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
|
GGML_LOG_ERROR("%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
|
||||||
__func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
|
__func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
|
@ -414,7 +414,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
fprintf(stderr, "%s: failed to read header\n", __func__);
|
GGML_LOG_ERROR("%s: failed to read header\n", __func__);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -430,15 +430,15 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
try {
|
try {
|
||||||
ok = ok && gr.read(key);
|
ok = ok && gr.read(key);
|
||||||
} catch (std::length_error &) {
|
} catch (std::length_error &) {
|
||||||
fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
|
GGML_LOG_ERROR("%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
|
||||||
ok = false;
|
ok = false;
|
||||||
} catch (std::bad_alloc &) {
|
} catch (std::bad_alloc &) {
|
||||||
fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
|
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
|
for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
|
||||||
if (key == ctx->kv[j].key) {
|
if (key == ctx->kv[j].key) {
|
||||||
fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
|
GGML_LOG_ERROR("%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -479,14 +479,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
case GGUF_TYPE_ARRAY:
|
case GGUF_TYPE_ARRAY:
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
|
GGML_LOG_ERROR("%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
|
||||||
ok = false;
|
ok = false;
|
||||||
} break;
|
} break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
|
GGML_LOG_ERROR("%s: failed to read key-value pairs\n", __func__);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -496,7 +496,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);
|
ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);
|
||||||
|
|
||||||
if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
|
if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
|
||||||
fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
|
GGML_LOG_ERROR("%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -512,14 +512,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
try {
|
try {
|
||||||
ok = ok && gr.read(name);
|
ok = ok && gr.read(name);
|
||||||
} catch (std::length_error &) {
|
} catch (std::length_error &) {
|
||||||
fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
|
GGML_LOG_ERROR("%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||||
ok = false;
|
ok = false;
|
||||||
} catch (std::bad_alloc &) {
|
} catch (std::bad_alloc &) {
|
||||||
fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
|
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
if (name.length() >= GGML_MAX_NAME) {
|
if (name.length() >= GGML_MAX_NAME) {
|
||||||
fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
|
GGML_LOG_ERROR("%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
|
||||||
ok = false;
|
ok = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -528,7 +528,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
// make sure there are no duplicate tensor names
|
// make sure there are no duplicate tensor names
|
||||||
for (int64_t j = 0; ok && j < i; ++j) {
|
for (int64_t j = 0; ok && j < i; ++j) {
|
||||||
if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
|
if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
|
||||||
fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
|
GGML_LOG_ERROR("%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
|
||||||
ok = false;
|
ok = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -543,7 +543,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
uint32_t n_dims = -1;
|
uint32_t n_dims = -1;
|
||||||
ok = ok && gr.read(n_dims);
|
ok = ok && gr.read(n_dims);
|
||||||
if (n_dims > GGML_MAX_DIMS) {
|
if (n_dims > GGML_MAX_DIMS) {
|
||||||
fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
|
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
|
||||||
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
|
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
|
||||||
ok = false;
|
ok = false;
|
||||||
break;
|
break;
|
||||||
|
@ -563,7 +563,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
|
|
||||||
// check that all ne are non-negative
|
// check that all ne are non-negative
|
||||||
if (info.t.ne[j] < 0) {
|
if (info.t.ne[j] < 0) {
|
||||||
fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
|
GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
|
||||||
__func__, info.t.name, j, info.t.ne[j]);
|
__func__, info.t.name, j, info.t.ne[j]);
|
||||||
ok = false;
|
ok = false;
|
||||||
break;
|
break;
|
||||||
|
@ -575,7 +575,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
(INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
|
(INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
|
||||||
(INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {
|
(INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {
|
||||||
|
|
||||||
fprintf(stderr, "%s: total number of elements in tensor '%s' with shape "
|
GGML_LOG_ERROR("%s: total number of elements in tensor '%s' with shape "
|
||||||
"(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
|
"(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
|
||||||
__func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
|
__func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
|
||||||
ok = false;
|
ok = false;
|
||||||
|
@ -592,7 +592,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
|
|
||||||
// check that tensor type is within defined range
|
// check that tensor type is within defined range
|
||||||
if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
|
if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
|
||||||
fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n",
|
GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n",
|
||||||
__func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
|
__func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
|
||||||
ok = false;
|
ok = false;
|
||||||
break;
|
break;
|
||||||
|
@ -602,7 +602,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
|
|
||||||
// check that row size is divisible by block size
|
// check that row size is divisible by block size
|
||||||
if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
|
if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
|
||||||
fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
|
GGML_LOG_ERROR("%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
|
||||||
"not a multiple of block size (%" PRId64 ")\n",
|
"not a multiple of block size (%" PRId64 ")\n",
|
||||||
__func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
|
__func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
|
||||||
ok = false;
|
ok = false;
|
||||||
|
@ -627,7 +627,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
fprintf(stderr, "%s: failed to read tensor info\n", __func__);
|
GGML_LOG_ERROR("%s: failed to read tensor info\n", __func__);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -635,7 +635,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
|
|
||||||
// we require the data section to be aligned, so take into account any padding
|
// we require the data section to be aligned, so take into account any padding
|
||||||
if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
|
if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
|
||||||
fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__);
|
GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -649,9 +649,9 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
for (size_t i = 0; i < ctx->info.size(); ++i) {
|
for (size_t i = 0; i < ctx->info.size(); ++i) {
|
||||||
const gguf_tensor_info & ti = ctx->info[i];
|
const gguf_tensor_info & ti = ctx->info[i];
|
||||||
if (ti.offset != ctx->size) {
|
if (ti.offset != ctx->size) {
|
||||||
fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
|
GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
|
||||||
__func__, ti.t.name, ti.offset, ctx->size);
|
__func__, ti.t.name, ti.offset, ctx->size);
|
||||||
fprintf(stderr, "%s: failed to read tensor data\n", __func__);
|
GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -679,7 +679,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
|
|
||||||
*params.ctx = ggml_init(pdata);
|
*params.ctx = ggml_init(pdata);
|
||||||
if (*params.ctx == nullptr) {
|
if (*params.ctx == nullptr) {
|
||||||
fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__);
|
GGML_LOG_ERROR("%s: failed to initialize ggml context for storing tensors\n", __func__);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -701,7 +701,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
ok = ok && gr.read(data->data, ctx->size);
|
ok = ok && gr.read(data->data, ctx->size);
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__);
|
GGML_LOG_ERROR("%s: failed to read tensor data binary blob\n", __func__);
|
||||||
ggml_free(ctx_data);
|
ggml_free(ctx_data);
|
||||||
*params.ctx = nullptr;
|
*params.ctx = nullptr;
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
|
@ -734,7 +734,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
fprintf(stderr, "%s: failed to create tensors\n", __func__);
|
GGML_LOG_ERROR("%s: failed to create tensors\n", __func__);
|
||||||
ggml_free(ctx_data);
|
ggml_free(ctx_data);
|
||||||
*params.ctx = nullptr;
|
*params.ctx = nullptr;
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
|
@ -751,7 +751,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||||
FILE * file = ggml_fopen(fname, "rb");
|
FILE * file = ggml_fopen(fname, "rb");
|
||||||
|
|
||||||
if (!file) {
|
if (!file) {
|
||||||
fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname);
|
GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1350,7 +1350,7 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
|
||||||
FILE * file = ggml_fopen(fname, "wb");
|
FILE * file = ggml_fopen(fname, "wb");
|
||||||
|
|
||||||
if (!file) {
|
if (!file) {
|
||||||
fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
|
GGML_LOG_ERROR("%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -823,6 +823,7 @@ class GGUFEditorWindow(QMainWindow):
|
||||||
self.modified = False
|
self.modified = False
|
||||||
self.metadata_changes = {} # Store changes to apply when saving
|
self.metadata_changes = {} # Store changes to apply when saving
|
||||||
self.metadata_to_remove = set() # Store keys to remove when saving
|
self.metadata_to_remove = set() # Store keys to remove when saving
|
||||||
|
self.on_metadata_changed_is_connected = False
|
||||||
|
|
||||||
self.setup_ui()
|
self.setup_ui()
|
||||||
|
|
||||||
|
@ -941,9 +942,11 @@ class GGUFEditorWindow(QMainWindow):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Disconnect to prevent triggering during loading
|
# Disconnect to prevent triggering during loading
|
||||||
with warnings.catch_warnings():
|
if self.on_metadata_changed_is_connected:
|
||||||
warnings.filterwarnings('ignore')
|
with warnings.catch_warnings():
|
||||||
self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
|
warnings.filterwarnings('ignore')
|
||||||
|
self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
|
||||||
|
self.on_metadata_changed_is_connected = False
|
||||||
|
|
||||||
for i, (key, field) in enumerate(self.reader.fields.items()):
|
for i, (key, field) in enumerate(self.reader.fields.items()):
|
||||||
self.metadata_table.insertRow(i)
|
self.metadata_table.insertRow(i)
|
||||||
|
@ -1021,6 +1024,7 @@ class GGUFEditorWindow(QMainWindow):
|
||||||
|
|
||||||
# Reconnect after loading
|
# Reconnect after loading
|
||||||
self.metadata_table.itemChanged.connect(self.on_metadata_changed)
|
self.metadata_table.itemChanged.connect(self.on_metadata_changed)
|
||||||
|
self.on_metadata_changed_is_connected = True
|
||||||
|
|
||||||
def extract_array_values(self, field: ReaderField) -> list:
|
def extract_array_values(self, field: ReaderField) -> list:
|
||||||
"""Extract all values from an array field."""
|
"""Extract all values from an array field."""
|
||||||
|
|
|
@ -68,7 +68,7 @@ class TensorNameMap:
|
||||||
"output_layer", # chatglm
|
"output_layer", # chatglm
|
||||||
"head", # rwkv
|
"head", # rwkv
|
||||||
"head.out", # wavtokenizer
|
"head.out", # wavtokenizer
|
||||||
"language_model.lm_head", # llama4
|
"lm_head", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# Output norm
|
# Output norm
|
||||||
|
@ -91,7 +91,7 @@ class TensorNameMap:
|
||||||
"rwkv.ln_out", # rwkv6
|
"rwkv.ln_out", # rwkv6
|
||||||
"model.ln_out", # rwkv7
|
"model.ln_out", # rwkv7
|
||||||
"backbone.final_layer_norm", # wavtokenizer
|
"backbone.final_layer_norm", # wavtokenizer
|
||||||
"language_model.model.norm", # llama4
|
"model.norm", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# Rope frequencies
|
# Rope frequencies
|
||||||
|
@ -133,7 +133,7 @@ class TensorNameMap:
|
||||||
"transformer.layers.{bid}.attn_norm", # openelm
|
"transformer.layers.{bid}.attn_norm", # openelm
|
||||||
"rwkv.blocks.{bid}.ln1", # rwkv6
|
"rwkv.blocks.{bid}.ln1", # rwkv6
|
||||||
"model.layers.{bid}.ln1", # rwkv7
|
"model.layers.{bid}.ln1", # rwkv7
|
||||||
"language_model.model.layers.{bid}.input_layernorm", # llama4
|
"model.layers.{bid}.input_layernorm", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention norm 2
|
# Attention norm 2
|
||||||
|
@ -173,7 +173,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.attention.wq", # internlm2
|
"model.layers.{bid}.attention.wq", # internlm2
|
||||||
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
|
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
|
||||||
"transformer.h.{bid}.attn.attention.q_proj", # exaone
|
"transformer.h.{bid}.attn.attention.q_proj", # exaone
|
||||||
"language_model.model.layers.{bid}.self_attn.q_proj", # llama4
|
"model.layers.{bid}.self_attn.q_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention key
|
# Attention key
|
||||||
|
@ -188,7 +188,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.attention.wk", # internlm2
|
"model.layers.{bid}.attention.wk", # internlm2
|
||||||
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
|
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
|
||||||
"transformer.h.{bid}.attn.attention.k_proj", # exaone
|
"transformer.h.{bid}.attn.attention.k_proj", # exaone
|
||||||
"language_model.model.layers.{bid}.self_attn.k_proj", # llama4
|
"model.layers.{bid}.self_attn.k_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention value
|
# Attention value
|
||||||
|
@ -202,7 +202,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.attention.wv", # internlm2
|
"model.layers.{bid}.attention.wv", # internlm2
|
||||||
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
|
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
|
||||||
"transformer.h.{bid}.attn.attention.v_proj", # exaone
|
"transformer.h.{bid}.attn.attention.v_proj", # exaone
|
||||||
"language_model.model.layers.{bid}.self_attn.v_proj", # llama4
|
"model.layers.{bid}.self_attn.v_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention output
|
# Attention output
|
||||||
|
@ -229,7 +229,7 @@ class TensorNameMap:
|
||||||
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
||||||
"transformer.layers.{bid}.attn.out_proj", # openelm
|
"transformer.layers.{bid}.attn.out_proj", # openelm
|
||||||
"transformer.h.{bid}.attn.attention.out_proj", # exaone
|
"transformer.h.{bid}.attn.attention.out_proj", # exaone
|
||||||
"language_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
"model.layers.{bid}.self_attn.o_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention output norm
|
# Attention output norm
|
||||||
|
@ -268,7 +268,7 @@ class TensorNameMap:
|
||||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||||
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
|
"model.layers.{bid}.post_attention_layernorm", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# Post feed-forward norm
|
# Post feed-forward norm
|
||||||
|
@ -289,7 +289,7 @@ class TensorNameMap:
|
||||||
"transformer.decoder_layer.{bid}.router", # Grok
|
"transformer.decoder_layer.{bid}.router", # Grok
|
||||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||||
"language_model.model.layers.{bid}.feed_forward.router", # llama4
|
"model.layers.{bid}.feed_forward.router", # llama4
|
||||||
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
|
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -329,7 +329,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||||
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
|
"model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_UP_EXP: (
|
MODEL_TENSOR.FFN_UP_EXP: (
|
||||||
|
@ -338,14 +338,14 @@ class TensorNameMap:
|
||||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||||
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
||||||
"language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||||
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
|
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# AWQ-activation gate
|
# AWQ-activation gate
|
||||||
|
@ -366,22 +366,22 @@ class TensorNameMap:
|
||||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||||
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
|
"model.layers.{bid}.feed_forward.gate_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||||
"language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
# Feed-forward down
|
# Feed-forward down
|
||||||
|
@ -410,7 +410,7 @@ class TensorNameMap:
|
||||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||||
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
|
"model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||||
|
@ -420,15 +420,15 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||||
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
||||||
"language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||||
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
|
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||||
|
|
|
@ -1704,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
|
||||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||||
|
|
||||||
kv_self->state_write(io);
|
if (kv_self != nullptr) {
|
||||||
|
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||||
|
kv_self->state_write(io);
|
||||||
|
}
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
|
|
|
@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
|
||||||
|
|
||||||
void llama_kv_cache_unified::set_full() {
|
void llama_kv_cache_unified::set_full() {
|
||||||
n = size;
|
n = size;
|
||||||
|
|
||||||
|
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
||||||
|
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
||||||
|
// we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
|
||||||
|
// setting it to 0 is the simplest way to achieve that
|
||||||
|
// ref: https://github.com/ggml-org/llama.cpp/issues/13359
|
||||||
|
head = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sbatch llama_kv_cache_unified::sbatch_init(
|
llama_sbatch llama_kv_cache_unified::sbatch_init(
|
||||||
|
@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
|
||||||
|
|
||||||
void llama_kv_cache_recurrent::set_full() {
|
void llama_kv_cache_recurrent::set_full() {
|
||||||
n = size;
|
n = size;
|
||||||
|
head = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
|
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
|
||||||
|
|
|
@ -171,11 +171,8 @@ public:
|
||||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||||
|
|
||||||
// Note: The value of head isn't only used to optimize searching
|
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||||
// cannot be freely changed after a slot has been allocated.
|
|
||||||
uint32_t head = 0;
|
|
||||||
uint32_t size = 0;
|
|
||||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||||
|
|
||||||
// computed before each graph build
|
// computed before each graph build
|
||||||
|
@ -343,11 +340,8 @@ public:
|
||||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||||
|
|
||||||
// Note: The value of head isn't only used to optimize searching
|
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||||
// cannot be freely changed after a slot has been allocated.
|
|
||||||
uint32_t head = 0;
|
|
||||||
uint32_t size = 0;
|
|
||||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||||
|
|
||||||
// computed before each graph build
|
// computed before each graph build
|
||||||
|
|
|
@ -473,7 +473,7 @@ llama_model_loader::llama_model_loader(
|
||||||
|
|
||||||
meta.reset(gguf_init_from_file(fname.c_str(), params));
|
meta.reset(gguf_init_from_file(fname.c_str(), params));
|
||||||
if (!meta) {
|
if (!meta) {
|
||||||
throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
|
throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str()));
|
||||||
}
|
}
|
||||||
|
|
||||||
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
||||||
|
@ -532,7 +532,7 @@ llama_model_loader::llama_model_loader(
|
||||||
};
|
};
|
||||||
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
|
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
|
||||||
if (!ctx_gguf) {
|
if (!ctx_gguf) {
|
||||||
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
|
throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split));
|
||||||
}
|
}
|
||||||
|
|
||||||
// check idx
|
// check idx
|
||||||
|
@ -827,13 +827,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
|
||||||
mappings.reserve(files.size());
|
mappings.reserve(files.size());
|
||||||
mmaps_used.reserve(files.size());
|
mmaps_used.reserve(files.size());
|
||||||
for (const auto & file : files) {
|
for (const auto & file : files) {
|
||||||
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
|
bool is_numa = false;
|
||||||
if (!reg) {
|
|
||||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||||
|
if (dev) {
|
||||||
|
auto * reg = ggml_backend_dev_backend_reg(dev);
|
||||||
|
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
||||||
|
if (is_numa_fn) {
|
||||||
|
is_numa = is_numa_fn();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa);
|
||||||
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
|
|
||||||
mmaps_used.emplace_back(mapping->size(), 0);
|
mmaps_used.emplace_back(mapping->size(), 0);
|
||||||
if (mlock_mmaps) {
|
if (mlock_mmaps) {
|
||||||
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
|
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
|
||||||
|
|
|
@ -12322,6 +12322,9 @@ struct llm_build_granite : public llm_graph_context {
|
||||||
|
|
||||||
// inp_pos - built only if rope enabled
|
// inp_pos - built only if rope enabled
|
||||||
ggml_tensor * inp_pos = nullptr;
|
ggml_tensor * inp_pos = nullptr;
|
||||||
|
if (use_rope) {
|
||||||
|
inp_pos = build_inp_pos();
|
||||||
|
}
|
||||||
|
|
||||||
auto * inp_attn = build_attn_inp_kv_unified();
|
auto * inp_attn = build_attn_inp_kv_unified();
|
||||||
|
|
||||||
|
@ -12364,10 +12367,6 @@ struct llm_build_granite : public llm_graph_context {
|
||||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
|
||||||
if (use_rope) {
|
if (use_rope) {
|
||||||
|
|
||||||
if (!inp_pos) {
|
|
||||||
inp_pos = build_inp_pos();
|
|
||||||
}
|
|
||||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, rope_factors,
|
ctx0, Qcur, inp_pos, rope_factors,
|
||||||
|
|
|
@ -14,6 +14,12 @@
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
// Quantization types. Changes to this struct must be replicated in quantize.cpp
|
||||||
|
struct tensor_quantization {
|
||||||
|
std::string name;
|
||||||
|
ggml_type quant = GGML_TYPE_COUNT;
|
||||||
|
};
|
||||||
|
|
||||||
static void zeros(std::ofstream & file, size_t n) {
|
static void zeros(std::ofstream & file, size_t n) {
|
||||||
char zero = 0;
|
char zero = 0;
|
||||||
for (size_t i = 0; i < n; ++i) {
|
for (size_t i = 0; i < n; ++i) {
|
||||||
|
@ -48,12 +54,6 @@ struct quantize_state_impl {
|
||||||
{}
|
{}
|
||||||
};
|
};
|
||||||
|
|
||||||
// changes to this struct must be replicated in quantize.cpp
|
|
||||||
struct tensor_quantization {
|
|
||||||
std::string name;
|
|
||||||
ggml_type quant = GGML_TYPE_COUNT;
|
|
||||||
};
|
|
||||||
|
|
||||||
static void llama_tensor_dequantize_impl(
|
static void llama_tensor_dequantize_impl(
|
||||||
ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
||||||
const size_t nelements, const int nthread
|
const size_t nelements, const int nthread
|
||||||
|
@ -799,17 +799,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
// unless the user specifies a type
|
// unless the user specifies a type
|
||||||
if (params->tensor_types) {
|
if (params->tensor_types) {
|
||||||
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
||||||
|
const std::string tensor_name(tensor->name);
|
||||||
for (const auto & [tname, qtype] : tensor_types) {
|
for (const auto & [tname, qtype] : tensor_types) {
|
||||||
if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
|
if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
|
||||||
if (qtype != new_type) {
|
if (qtype != new_type) {
|
||||||
LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
|
LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
|
||||||
|
new_type = qtype;
|
||||||
|
break; // if two or more types are specified for the tensor, first match wins
|
||||||
}
|
}
|
||||||
new_type = qtype;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||||
new_type = params->token_embedding_type;
|
new_type = params->token_embedding_type;
|
||||||
}
|
}
|
||||||
|
|
288
tests/test-regex-partial.cpp
Normal file
288
tests/test-regex-partial.cpp
Normal file
|
@ -0,0 +1,288 @@
|
||||||
|
// Tests common_regex (esp. its partial final matches support).
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "regex-partial.h"
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
template <class T> static void assert_equals(const T & expected, const T & actual) {
|
||||||
|
if (expected != actual) {
|
||||||
|
std::cerr << "Expected: " << expected << std::endl;
|
||||||
|
std::cerr << " Actual: " << actual << std::endl;
|
||||||
|
std::cerr << std::flush;
|
||||||
|
throw std::runtime_error("Test failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct test_case {
|
||||||
|
std::string pattern;
|
||||||
|
struct input_output {
|
||||||
|
std::string input;
|
||||||
|
common_regex_match output;
|
||||||
|
};
|
||||||
|
std::vector<input_output> inputs_outputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
static std::string common_regex_match_type_name(common_regex_match_type type) {
|
||||||
|
switch (type) {
|
||||||
|
case COMMON_REGEX_MATCH_TYPE_NONE:
|
||||||
|
return "COMMON_REGEX_MATCH_TYPE_NONE";
|
||||||
|
case COMMON_REGEX_MATCH_TYPE_PARTIAL:
|
||||||
|
return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
|
||||||
|
case COMMON_REGEX_MATCH_TYPE_FULL:
|
||||||
|
return "COMMON_REGEX_MATCH_TYPE_FULL";
|
||||||
|
}
|
||||||
|
return "?";
|
||||||
|
}
|
||||||
|
|
||||||
|
static void test_regex() {
|
||||||
|
printf("[%s]\n", __func__);
|
||||||
|
auto test = [](const test_case & test_case) {
|
||||||
|
common_regex cr(test_case.pattern);
|
||||||
|
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
|
||||||
|
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
|
||||||
|
for (const auto & input_output : test_case.inputs_outputs) {
|
||||||
|
std::cout << " Input: " << input_output.input << '\n';
|
||||||
|
auto m = cr.search(input_output.input, 0);
|
||||||
|
if (m != input_output.output) {
|
||||||
|
auto match_to_str = [&](const std::optional<common_regex_match> & m) {
|
||||||
|
std::ostringstream ss;
|
||||||
|
if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||||
|
ss << "<no match>";
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(!input_output.output.groups.empty());
|
||||||
|
std::vector<std::string> parts;
|
||||||
|
for (const auto & g : m->groups) {
|
||||||
|
parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
|
||||||
|
}
|
||||||
|
ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
|
||||||
|
std::cout << " Got: " << match_to_str(m) << '\n';
|
||||||
|
std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
|
||||||
|
|
||||||
|
throw std::runtime_error("Test failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
test({
|
||||||
|
"a",
|
||||||
|
{
|
||||||
|
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
|
||||||
|
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
|
||||||
|
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
|
||||||
|
{"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
test({
|
||||||
|
"abcd",
|
||||||
|
{
|
||||||
|
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||||
|
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||||
|
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||||
|
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||||
|
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||||
|
{"d", {}},
|
||||||
|
{"bcd", {}},
|
||||||
|
{"cde", {}},
|
||||||
|
{"cd", {}},
|
||||||
|
{"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
|
||||||
|
{"abbie", {}},
|
||||||
|
{"", {}},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
test({
|
||||||
|
".*?ab",
|
||||||
|
{
|
||||||
|
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||||
|
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||||
|
{"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||||
|
{"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||||
|
{"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||||
|
{"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
test({
|
||||||
|
"a.*?b",
|
||||||
|
{
|
||||||
|
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||||
|
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||||
|
{"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||||
|
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||||
|
{"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||||
|
{"d", {}},
|
||||||
|
{"b", {}},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
test({
|
||||||
|
"ab(?:cd){2,4}ef",
|
||||||
|
{
|
||||||
|
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
|
||||||
|
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||||
|
{"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||||
|
{"abcde", {}},
|
||||||
|
{"abcdef", {}},
|
||||||
|
{"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||||
|
{"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
|
||||||
|
{"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
|
||||||
|
{"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
|
||||||
|
{"abcdcdcdcdcdef", {}},
|
||||||
|
{"abcde", {}},
|
||||||
|
{"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
test({
|
||||||
|
"a(?:rte| pure )fact",
|
||||||
|
{
|
||||||
|
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||||
|
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||||
|
{"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||||
|
{"fact", {}},
|
||||||
|
{"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
|
||||||
|
{"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
|
||||||
|
{"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
|
||||||
|
{"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||||
|
{"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
|
||||||
|
{"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
|
||||||
|
{"" , {}},
|
||||||
|
{"pure", {}},
|
||||||
|
{"pure fact", {}},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
test({
|
||||||
|
"abc",
|
||||||
|
{
|
||||||
|
{" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
|
||||||
|
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||||
|
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||||
|
{" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
|
||||||
|
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||||
|
{"b", {}},
|
||||||
|
{"c", {}},
|
||||||
|
{"", {}},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
test({
|
||||||
|
"(?:abc)?\\s*def",
|
||||||
|
{
|
||||||
|
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||||
|
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||||
|
{"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||||
|
{"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
|
||||||
|
{"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||||
|
{"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||||
|
{"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||||
|
{"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||||
|
{"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
|
||||||
|
{"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
|
||||||
|
{" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||||
|
{"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
test({
|
||||||
|
"a+b",
|
||||||
|
{
|
||||||
|
{"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||||
|
{"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||||
|
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
test({
|
||||||
|
"(?:"
|
||||||
|
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
||||||
|
"(" // match 2 (open_tag)
|
||||||
|
"<tool_call>"
|
||||||
|
"|<function_call>"
|
||||||
|
"|<tool>"
|
||||||
|
"|<tools>"
|
||||||
|
"|<response>"
|
||||||
|
"|<json>"
|
||||||
|
"|<xml>"
|
||||||
|
"|<JSON>"
|
||||||
|
")?"
|
||||||
|
"(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
|
||||||
|
")"
|
||||||
|
"|<function=([^>]+)>" // match 4 (function name)
|
||||||
|
"|<function name=\"([^\"]+)\">", // match 5 (function name again)
|
||||||
|
{
|
||||||
|
{"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
|
||||||
|
{"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
|
||||||
|
{"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
|
||||||
|
{"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
|
||||||
|
{"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
|
||||||
|
{"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||||
|
{"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
|
||||||
|
{"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
|
||||||
|
{"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
|
||||||
|
{"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
|
||||||
|
{"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
|
||||||
|
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static void test_regex_to_reversed_partial_regex() {
|
||||||
|
printf("[%s]\n", __func__);
|
||||||
|
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:(?:c)?b)?a)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("abc"));
|
||||||
|
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"(a+)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("a+"));
|
||||||
|
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"(a*)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("a*"));
|
||||||
|
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"(a?)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("a?"));
|
||||||
|
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"([a-z])[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("[a-z]"));
|
||||||
|
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:\\w+)?[a-z])[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("[a-z]\\w+"));
|
||||||
|
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:a|b))[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("(?:a|b)"));
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("abcd"));
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
|
||||||
|
regex_to_reversed_partial_regex("a*b"));
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:(?:b)?a)?.*)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex(".*?ab"));
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:(?:b)?.*)?a)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("a.*?b"));
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("a(bc)d"));
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("a(bc|de)"));
|
||||||
|
assert_equals<std::string>(
|
||||||
|
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
|
||||||
|
regex_to_reversed_partial_regex("ab{2,4}c"));
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
test_regex_to_reversed_partial_regex();
|
||||||
|
test_regex();
|
||||||
|
std::cout << "All tests passed.\n";
|
||||||
|
}
|
|
@ -58,6 +58,12 @@ static const std::vector<quant_option> QUANT_OPTIONS = {
|
||||||
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
|
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Quantization types. Changes to this struct must be replicated in llama-quantize.cpp
|
||||||
|
struct tensor_quantization {
|
||||||
|
std::string name;
|
||||||
|
ggml_type quant = GGML_TYPE_COUNT;
|
||||||
|
};
|
||||||
|
|
||||||
static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file";
|
static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file";
|
||||||
static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix.dataset";
|
static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix.dataset";
|
||||||
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
|
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
|
||||||
|
@ -245,56 +251,10 @@ static ggml_type parse_ggml_type(const char * arg) {
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
|
fprintf(stderr, "\n%s: invalid ggml_type '%s'\n\n", __func__, arg);
|
||||||
return GGML_TYPE_COUNT;
|
return GGML_TYPE_COUNT;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allowed tensors for arbitrary quantization with --tensor-type option
|
|
||||||
static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
|
|
||||||
"attn_k",
|
|
||||||
"attn_kv_a_mqa",
|
|
||||||
"attn_kv_b",
|
|
||||||
"attn_o",
|
|
||||||
"attn_output",
|
|
||||||
"attn_q",
|
|
||||||
"attn_q_a",
|
|
||||||
"attn_q_b",
|
|
||||||
"attn_qkv",
|
|
||||||
"attn_v",
|
|
||||||
"channel_mix_key",
|
|
||||||
"channel_mix_receptance",
|
|
||||||
"channel_mix_value",
|
|
||||||
"cls",
|
|
||||||
"cls.output",
|
|
||||||
"cross_attn_k",
|
|
||||||
"cross_attn_o",
|
|
||||||
"cross_attn_q",
|
|
||||||
"cross_attn_v",
|
|
||||||
"ffn_act",
|
|
||||||
"ffn_down",
|
|
||||||
"ffn_down_exps",
|
|
||||||
"ffn_down_shexp",
|
|
||||||
"ffn_gate",
|
|
||||||
"ffn_gate_exps",
|
|
||||||
"ffn_gate_shexp",
|
|
||||||
"ffn_up",
|
|
||||||
"ffn_up_exps",
|
|
||||||
"ffn_up_shexp",
|
|
||||||
"ssm_in",
|
|
||||||
"ssm_out",
|
|
||||||
"time_mix_gate",
|
|
||||||
"time_mix_key",
|
|
||||||
"time_mix_output",
|
|
||||||
"time_mix_receptance",
|
|
||||||
"time_mix_value",
|
|
||||||
};
|
|
||||||
|
|
||||||
// changes to this struct must be replicated in llama-quant.cpp
|
|
||||||
struct tensor_quantization {
|
|
||||||
std::string name;
|
|
||||||
ggml_type quant = GGML_TYPE_COUNT;
|
|
||||||
};
|
|
||||||
|
|
||||||
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
|
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
|
||||||
const char * sep = strchr(data, '=');
|
const char * sep = strchr(data, '=');
|
||||||
if (sep == nullptr) {
|
if (sep == nullptr) {
|
||||||
|
@ -307,7 +267,6 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
|
||||||
printf("\n%s: missing tensor name\n\n", __func__);
|
printf("\n%s: missing tensor name\n\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (const size_t qt_len = strlen(sep); qt_len == 1) {
|
if (const size_t qt_len = strlen(sep); qt_len == 1) {
|
||||||
printf("\n%s: missing quantization type\n\n", __func__);
|
printf("\n%s: missing quantization type\n\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
@ -316,37 +275,15 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
|
||||||
std::string tn(data, tn_len);
|
std::string tn(data, tn_len);
|
||||||
std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
|
std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
|
||||||
sep++;
|
sep++;
|
||||||
const std::string qt(sep);
|
|
||||||
|
|
||||||
bool found = false;
|
|
||||||
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
|
|
||||||
std::string tensor;
|
|
||||||
tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
|
|
||||||
// handle special case of cls.output
|
|
||||||
std::string cls_output = "cls.output";
|
|
||||||
if (tn.find(cls_output) != std::string::npos) {
|
|
||||||
tensor = "cls.output";
|
|
||||||
}
|
|
||||||
// check if an allowed tensor exists and it's at the end of the kv string
|
|
||||||
if (tensor == allowed) {
|
|
||||||
found = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!found) {
|
|
||||||
printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
|
|
||||||
printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
tensor_quantization tqz;
|
tensor_quantization tqz;
|
||||||
tqz.name = tn;
|
tqz.name = tn;
|
||||||
tqz.quant = parse_ggml_type(qt.c_str());
|
tqz.quant = parse_ggml_type(sep);
|
||||||
tensor_type.emplace_back(std::move(tqz));
|
tensor_type.emplace_back(std::move(tqz));
|
||||||
|
if (tqz.quant == GGML_TYPE_COUNT) {
|
||||||
|
printf("\n%s: invalid quantization type '%s'\n\n", __func__, sep);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Binary file not shown.
|
@ -1429,7 +1429,7 @@ struct server_slot {
|
||||||
pos = text.find(word, from_pos);
|
pos = text.find(word, from_pos);
|
||||||
} else {
|
} else {
|
||||||
// otherwise, partial stop
|
// otherwise, partial stop
|
||||||
pos = find_partial_stop_string(word, text);
|
pos = string_find_partial_stop(text, word);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
|
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
|
||||||
|
@ -2951,7 +2951,8 @@ struct server_context {
|
||||||
llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
||||||
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
|
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
|
||||||
|
|
||||||
if (slot.params.cache_prompt) {
|
// add generated tokens to cache
|
||||||
|
{
|
||||||
llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
|
llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
|
||||||
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
||||||
new_tokens[i - n_discard] = new_tokens[i];
|
new_tokens[i - n_discard] = new_tokens[i];
|
||||||
|
@ -2996,10 +2997,7 @@ struct server_context {
|
||||||
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
||||||
|
|
||||||
slot.n_past += 1;
|
slot.n_past += 1;
|
||||||
|
slot.cache_tokens.push_back(slot.sampled);
|
||||||
if (slot.params.cache_prompt) {
|
|
||||||
slot.cache_tokens.push_back(slot.sampled);
|
|
||||||
}
|
|
||||||
|
|
||||||
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
|
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
|
||||||
slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
|
slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
|
||||||
|
@ -3171,6 +3169,11 @@ struct server_context {
|
||||||
|
|
||||||
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
|
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// if we don't cache the prompt, we have to remove the entire KV cache
|
||||||
|
llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
|
||||||
|
slot.n_past = 0;
|
||||||
|
slot.cache_tokens.clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3204,7 +3207,7 @@ struct server_context {
|
||||||
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
||||||
|
|
||||||
// remove the non-common part from the cache
|
// remove the non-common part from the cache
|
||||||
slot.cache_tokens.resize(slot.n_past);
|
slot.cache_tokens.keep_first(slot.n_past);
|
||||||
|
|
||||||
// check if we should process the image
|
// check if we should process the image
|
||||||
if (slot.n_past < slot.n_prompt_tokens
|
if (slot.n_past < slot.n_prompt_tokens
|
||||||
|
@ -3221,7 +3224,8 @@ struct server_context {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.params.cache_prompt) {
|
// add the image chunk to cache
|
||||||
|
{
|
||||||
const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
|
const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
|
||||||
slot.cache_tokens.push_back(chunk.get()); // copy
|
slot.cache_tokens.push_back(chunk.get()); // copy
|
||||||
}
|
}
|
||||||
|
@ -3242,9 +3246,7 @@ struct server_context {
|
||||||
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
||||||
|
|
||||||
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
|
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
|
||||||
if (slot.params.cache_prompt) {
|
slot.cache_tokens.push_back(cur_tok);
|
||||||
slot.cache_tokens.push_back(cur_tok);
|
|
||||||
}
|
|
||||||
|
|
||||||
slot.n_prompt_tokens_processed++;
|
slot.n_prompt_tokens_processed++;
|
||||||
slot.n_past++;
|
slot.n_past++;
|
||||||
|
@ -3705,6 +3707,9 @@ int main(int argc, char ** argv) {
|
||||||
if (req.path == "/" || tmp.back() == "html") {
|
if (req.path == "/" || tmp.back() == "html") {
|
||||||
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||||
res.status = 503;
|
res.status = 503;
|
||||||
|
} else if (req.path == "/models" || req.path == "/v1/models") {
|
||||||
|
// allow the models endpoint to be accessed during loading
|
||||||
|
return true;
|
||||||
} else {
|
} else {
|
||||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||||
}
|
}
|
||||||
|
@ -4363,7 +4368,13 @@ int main(int argc, char ** argv) {
|
||||||
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
const auto handle_models = [¶ms, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||||
|
server_state current_state = state.load();
|
||||||
|
json model_meta = nullptr;
|
||||||
|
if (current_state == SERVER_STATE_READY) {
|
||||||
|
model_meta = ctx_server.model_meta();
|
||||||
|
}
|
||||||
|
|
||||||
json models = {
|
json models = {
|
||||||
{"object", "list"},
|
{"object", "list"},
|
||||||
{"data", {
|
{"data", {
|
||||||
|
@ -4372,7 +4383,7 @@ int main(int argc, char ** argv) {
|
||||||
{"object", "model"},
|
{"object", "model"},
|
||||||
{"created", std::time(0)},
|
{"created", std::time(0)},
|
||||||
{"owned_by", "llamacpp"},
|
{"owned_by", "llamacpp"},
|
||||||
{"meta", ctx_server.model_meta()}
|
{"meta", model_meta},
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
};
|
};
|
||||||
|
|
|
@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
|
||||||
assert res_cache.body["content"] == res_no_cache.body["content"]
|
assert res_cache.body["content"] == res_no_cache.body["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_nocache_long_input_prompt():
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/completion", data={
|
||||||
|
"prompt": "I believe the meaning of life is"*32,
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 1.0,
|
||||||
|
"cache_prompt": False,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
def test_completion_with_tokens_input():
|
def test_completion_with_tokens_input():
|
||||||
global server
|
global server
|
||||||
server.temperature = 0.0
|
server.temperature = 0.0
|
||||||
|
|
49
tools/server/tests/unit/test_template.py
Normal file
49
tools/server/tests/unit/test_template.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# ensure grandparent path is in sys.path
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from unit.test_tool_call import TEST_TOOL
|
||||||
|
path = Path(__file__).resolve().parents[1]
|
||||||
|
sys.path.insert(0, str(path))
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
from utils import *
|
||||||
|
|
||||||
|
server: ServerProcess
|
||||||
|
|
||||||
|
TIMEOUT_SERVER_START = 15*60
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def create_server():
|
||||||
|
global server
|
||||||
|
server = ServerPreset.tinyllama2()
|
||||||
|
server.model_alias = "tinyllama-2"
|
||||||
|
server.server_port = 8081
|
||||||
|
server.n_slots = 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||||
|
@pytest.mark.parametrize("template_name,format", [
|
||||||
|
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
|
||||||
|
("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
|
||||||
|
])
|
||||||
|
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
|
||||||
|
global server
|
||||||
|
server.jinja = True
|
||||||
|
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||||
|
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||||
|
|
||||||
|
res = server.make_request("POST", "/apply-template", data={
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is today?"},
|
||||||
|
],
|
||||||
|
"tools": tools,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
prompt = res.body["prompt"]
|
||||||
|
|
||||||
|
today_str = datetime.date.today().strftime(format)
|
||||||
|
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
|
|
@ -109,7 +109,7 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
||||||
])
|
])
|
||||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
||||||
global server
|
global server
|
||||||
n_predict = 512
|
n_predict = 1024
|
||||||
# server = ServerPreset.stories15m_moe()
|
# server = ServerPreset.stories15m_moe()
|
||||||
server.jinja = True
|
server.jinja = True
|
||||||
server.n_predict = n_predict
|
server.n_predict = n_predict
|
||||||
|
|
|
@ -643,6 +643,18 @@ static json oaicompat_completion_params_parse(
|
||||||
throw std::runtime_error("Expected 'messages' to be an array");
|
throw std::runtime_error("Expected 'messages' to be an array");
|
||||||
}
|
}
|
||||||
for (auto & msg : messages) {
|
for (auto & msg : messages) {
|
||||||
|
std::string role = json_value(msg, "role", std::string());
|
||||||
|
if (role != "assistant" && !msg.contains("content")) {
|
||||||
|
throw std::runtime_error("All non-assistant messages must contain 'content'");
|
||||||
|
}
|
||||||
|
if (role == "assistant") {
|
||||||
|
if (!msg.contains("content") && !msg.contains("tool_calls")) {
|
||||||
|
throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
|
||||||
|
}
|
||||||
|
if (!msg.contains("content")) {
|
||||||
|
continue; // avoid errors with no content
|
||||||
|
}
|
||||||
|
}
|
||||||
json & content = msg.at("content");
|
json & content = msg.at("content");
|
||||||
if (content.is_string() || content.is_null()) {
|
if (content.is_string() || content.is_null()) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -1153,7 +1165,7 @@ public:
|
||||||
tokens.clear();
|
tokens.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void resize(size_t n) {
|
void keep_first(size_t n) {
|
||||||
GGML_ASSERT(n <= tokens.size());
|
GGML_ASSERT(n <= tokens.size());
|
||||||
if (has_mtmd) {
|
if (has_mtmd) {
|
||||||
// we throw an error if we try to remove a token in the middle of an image
|
// we throw an error if we try to remove a token in the middle of an image
|
||||||
|
|
286
tools/server/webui/package-lock.json
generated
286
tools/server/webui/package-lock.json
generated
|
@ -18,6 +18,7 @@
|
||||||
"dexie": "^4.0.11",
|
"dexie": "^4.0.11",
|
||||||
"highlight.js": "^11.10.0",
|
"highlight.js": "^11.10.0",
|
||||||
"katex": "^0.16.15",
|
"katex": "^0.16.15",
|
||||||
|
"pdfjs-dist": "^5.2.133",
|
||||||
"postcss": "^8.4.49",
|
"postcss": "^8.4.49",
|
||||||
"react": "^18.3.1",
|
"react": "^18.3.1",
|
||||||
"react-dom": "^18.3.1",
|
"react-dom": "^18.3.1",
|
||||||
|
@ -44,6 +45,7 @@
|
||||||
"eslint": "^9.17.0",
|
"eslint": "^9.17.0",
|
||||||
"eslint-plugin-react-hooks": "^5.0.0",
|
"eslint-plugin-react-hooks": "^5.0.0",
|
||||||
"eslint-plugin-react-refresh": "^0.4.16",
|
"eslint-plugin-react-refresh": "^0.4.16",
|
||||||
|
"fflate": "^0.8.2",
|
||||||
"globals": "^15.14.0",
|
"globals": "^15.14.0",
|
||||||
"prettier": "^3.4.2",
|
"prettier": "^3.4.2",
|
||||||
"sass-embedded": "^1.83.4",
|
"sass-embedded": "^1.83.4",
|
||||||
|
@ -987,7 +989,7 @@
|
||||||
"version": "0.3.8",
|
"version": "0.3.8",
|
||||||
"resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz",
|
"resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz",
|
||||||
"integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==",
|
"integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==",
|
||||||
"dev": true,
|
"devOptional": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@jridgewell/set-array": "^1.2.1",
|
"@jridgewell/set-array": "^1.2.1",
|
||||||
|
@ -1002,7 +1004,7 @@
|
||||||
"version": "3.1.2",
|
"version": "3.1.2",
|
||||||
"resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz",
|
"resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz",
|
||||||
"integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==",
|
"integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==",
|
||||||
"dev": true,
|
"devOptional": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=6.0.0"
|
"node": ">=6.0.0"
|
||||||
|
@ -1012,30 +1014,224 @@
|
||||||
"version": "1.2.1",
|
"version": "1.2.1",
|
||||||
"resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz",
|
"resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz",
|
||||||
"integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==",
|
"integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==",
|
||||||
"dev": true,
|
"devOptional": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=6.0.0"
|
"node": ">=6.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@jridgewell/source-map": {
|
||||||
|
"version": "0.3.6",
|
||||||
|
"resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz",
|
||||||
|
"integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==",
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"peer": true,
|
||||||
|
"dependencies": {
|
||||||
|
"@jridgewell/gen-mapping": "^0.3.5",
|
||||||
|
"@jridgewell/trace-mapping": "^0.3.25"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@jridgewell/sourcemap-codec": {
|
"node_modules/@jridgewell/sourcemap-codec": {
|
||||||
"version": "1.5.0",
|
"version": "1.5.0",
|
||||||
"resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz",
|
"resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz",
|
||||||
"integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==",
|
"integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==",
|
||||||
"dev": true,
|
"devOptional": true,
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
"node_modules/@jridgewell/trace-mapping": {
|
"node_modules/@jridgewell/trace-mapping": {
|
||||||
"version": "0.3.25",
|
"version": "0.3.25",
|
||||||
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz",
|
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz",
|
||||||
"integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==",
|
"integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==",
|
||||||
"dev": true,
|
"devOptional": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@jridgewell/resolve-uri": "^3.1.0",
|
"@jridgewell/resolve-uri": "^3.1.0",
|
||||||
"@jridgewell/sourcemap-codec": "^1.4.14"
|
"@jridgewell/sourcemap-codec": "^1.4.14"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@napi-rs/canvas": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-nD6NGa4JbNYSZYsTnLGrqe9Kn/lCkA4ybXt8sx5ojDqZjr2i0TWAHxx/vhgfjX+i3hCdKWufxYwi7CfXqtITSA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
},
|
||||||
|
"optionalDependencies": {
|
||||||
|
"@napi-rs/canvas-android-arm64": "0.1.70",
|
||||||
|
"@napi-rs/canvas-darwin-arm64": "0.1.70",
|
||||||
|
"@napi-rs/canvas-darwin-x64": "0.1.70",
|
||||||
|
"@napi-rs/canvas-linux-arm-gnueabihf": "0.1.70",
|
||||||
|
"@napi-rs/canvas-linux-arm64-gnu": "0.1.70",
|
||||||
|
"@napi-rs/canvas-linux-arm64-musl": "0.1.70",
|
||||||
|
"@napi-rs/canvas-linux-riscv64-gnu": "0.1.70",
|
||||||
|
"@napi-rs/canvas-linux-x64-gnu": "0.1.70",
|
||||||
|
"@napi-rs/canvas-linux-x64-musl": "0.1.70",
|
||||||
|
"@napi-rs/canvas-win32-x64-msvc": "0.1.70"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-android-arm64": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-I/YOuQ0wbkVYxVaYtCgN42WKTYxNqFA0gTcTrHIGG1jfpDSyZWII/uHcjOo4nzd19io6Y4+/BqP8E5hJgf9OmQ==",
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"android"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-darwin-arm64": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-4pPGyXetHIHkw2TOJHujt3mkCP8LdDu8+CT15ld9Id39c752RcI0amDHSuMLMQfAjvusA9B5kKxazwjMGjEJpQ==",
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"darwin"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-darwin-x64": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-+2N6Os9LbkmDMHL+raknrUcLQhsXzc5CSXRbXws9C3pv/mjHRVszQ9dhFUUe9FjfPhCJznO6USVdwOtu7pOrzQ==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"darwin"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-linux-arm-gnueabihf": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-QjscX9OaKq/990sVhSMj581xuqLgiaPVMjjYvWaCmAJRkNQ004QfoSMEm3FoTqM4DRoquP8jvuEXScVJsc1rqQ==",
|
||||||
|
"cpu": [
|
||||||
|
"arm"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-linux-arm64-gnu": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-LNakMOwwqwiHIwMpnMAbFRczQMQ7TkkMyATqFCOtUJNlE6LPP/QiUj/mlFrNbUn/hctqShJ60gWEb52ZTALbVw==",
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-linux-arm64-musl": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-wBTOllEYNfJCHOdZj9v8gLzZ4oY3oyPX8MSRvaxPm/s7RfEXxCyZ8OhJ5xAyicsDdbE5YBZqdmaaeP5+xKxvtg==",
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-linux-riscv64-gnu": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-GVUUPC8TuuFqHip0rxHkUqArQnlzmlXmTEBuXAWdgCv85zTCFH8nOHk/YCF5yo0Z2eOm8nOi90aWs0leJ4OE5Q==",
|
||||||
|
"cpu": [
|
||||||
|
"riscv64"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-linux-x64-gnu": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-/kvUa2lZRwGNyfznSn5t1ShWJnr/m5acSlhTV3eXECafObjl0VBuA1HJw0QrilLpb4Fe0VLywkpD1NsMoVDROQ==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-linux-x64-musl": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-aqlv8MLpycoMKRmds7JWCfVwNf1fiZxaU7JwJs9/ExjTD8lX2KjsO7CTeAj5Cl4aEuzxUWbJPUUE2Qu9cZ1vfg==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@napi-rs/canvas-win32-x64-msvc": {
|
||||||
|
"version": "0.1.70",
|
||||||
|
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.70.tgz",
|
||||||
|
"integrity": "sha512-Q9QU3WIpwBTVHk4cPfBjGHGU4U0llQYRXgJtFtYqqGNEOKVN4OT6PQ+ve63xwIPODMpZ0HHyj/KLGc9CWc3EtQ==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"win32"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@nodelib/fs.scandir": {
|
"node_modules/@nodelib/fs.scandir": {
|
||||||
"version": "2.1.5",
|
"version": "2.1.5",
|
||||||
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
|
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
|
||||||
|
@ -2001,7 +2197,7 @@
|
||||||
"version": "8.14.0",
|
"version": "8.14.0",
|
||||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz",
|
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz",
|
||||||
"integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==",
|
"integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==",
|
||||||
"dev": true,
|
"devOptional": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"bin": {
|
"bin": {
|
||||||
"acorn": "bin/acorn"
|
"acorn": "bin/acorn"
|
||||||
|
@ -2185,6 +2381,14 @@
|
||||||
"devOptional": true,
|
"devOptional": true,
|
||||||
"license": "MIT/X11"
|
"license": "MIT/X11"
|
||||||
},
|
},
|
||||||
|
"node_modules/buffer-from": {
|
||||||
|
"version": "1.1.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz",
|
||||||
|
"integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==",
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"peer": true
|
||||||
|
},
|
||||||
"node_modules/callsites": {
|
"node_modules/callsites": {
|
||||||
"version": "3.1.0",
|
"version": "3.1.0",
|
||||||
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
|
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
|
||||||
|
@ -2802,6 +3006,13 @@
|
||||||
"reusify": "^1.0.4"
|
"reusify": "^1.0.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/fflate": {
|
||||||
|
"version": "0.8.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz",
|
||||||
|
"integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==",
|
||||||
|
"dev": true,
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/file-entry-cache": {
|
"node_modules/file-entry-cache": {
|
||||||
"version": "8.0.0",
|
"version": "8.0.0",
|
||||||
"resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz",
|
"resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz",
|
||||||
|
@ -4835,6 +5046,18 @@
|
||||||
"node": ">=8"
|
"node": ">=8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/pdfjs-dist": {
|
||||||
|
"version": "5.2.133",
|
||||||
|
"resolved": "https://registry.npmjs.org/pdfjs-dist/-/pdfjs-dist-5.2.133.tgz",
|
||||||
|
"integrity": "sha512-abE6ZWDxztt+gGFzfm4bX2ggfxUk9wsDEoFzIJm9LozaY3JdXR7jyLK4Bjs+XLXplCduuWS1wGhPC4tgTn/kzg==",
|
||||||
|
"license": "Apache-2.0",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=20.16.0 || >=22.3.0"
|
||||||
|
},
|
||||||
|
"optionalDependencies": {
|
||||||
|
"@napi-rs/canvas": "^0.1.67"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/picocolors": {
|
"node_modules/picocolors": {
|
||||||
"version": "1.1.1",
|
"version": "1.1.1",
|
||||||
"resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz",
|
"resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz",
|
||||||
|
@ -5745,6 +5968,17 @@
|
||||||
"node": ">=8"
|
"node": ">=8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/source-map": {
|
||||||
|
"version": "0.6.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
|
||||||
|
"integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==",
|
||||||
|
"license": "BSD-3-Clause",
|
||||||
|
"optional": true,
|
||||||
|
"peer": true,
|
||||||
|
"engines": {
|
||||||
|
"node": ">=0.10.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/source-map-js": {
|
"node_modules/source-map-js": {
|
||||||
"version": "1.2.1",
|
"version": "1.2.1",
|
||||||
"resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",
|
"resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",
|
||||||
|
@ -5754,6 +5988,18 @@
|
||||||
"node": ">=0.10.0"
|
"node": ">=0.10.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/source-map-support": {
|
||||||
|
"version": "0.5.21",
|
||||||
|
"resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz",
|
||||||
|
"integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==",
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"peer": true,
|
||||||
|
"dependencies": {
|
||||||
|
"buffer-from": "^1.0.0",
|
||||||
|
"source-map": "^0.6.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/space-separated-tokens": {
|
"node_modules/space-separated-tokens": {
|
||||||
"version": "2.0.2",
|
"version": "2.0.2",
|
||||||
"resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz",
|
"resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz",
|
||||||
|
@ -5851,6 +6097,34 @@
|
||||||
"node": ">=6"
|
"node": ">=6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/terser": {
|
||||||
|
"version": "5.39.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/terser/-/terser-5.39.1.tgz",
|
||||||
|
"integrity": "sha512-Mm6+uad0ZuDtcV8/4uOZQDQ8RuiC5Pu+iZRedJtF7yA/27sPL7d++In/AJKpWZlU3SYMPPkVfwetn6sgZ66pUA==",
|
||||||
|
"license": "BSD-2-Clause",
|
||||||
|
"optional": true,
|
||||||
|
"peer": true,
|
||||||
|
"dependencies": {
|
||||||
|
"@jridgewell/source-map": "^0.3.3",
|
||||||
|
"acorn": "^8.8.2",
|
||||||
|
"commander": "^2.20.0",
|
||||||
|
"source-map-support": "~0.5.20"
|
||||||
|
},
|
||||||
|
"bin": {
|
||||||
|
"terser": "bin/terser"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/terser/node_modules/commander": {
|
||||||
|
"version": "2.20.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz",
|
||||||
|
"integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==",
|
||||||
|
"license": "MIT",
|
||||||
|
"optional": true,
|
||||||
|
"peer": true
|
||||||
|
},
|
||||||
"node_modules/textlinestream": {
|
"node_modules/textlinestream": {
|
||||||
"version": "1.1.1",
|
"version": "1.1.1",
|
||||||
"resolved": "https://registry.npmjs.org/textlinestream/-/textlinestream-1.1.1.tgz",
|
"resolved": "https://registry.npmjs.org/textlinestream/-/textlinestream-1.1.1.tgz",
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
"type": "module",
|
"type": "module",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "vite",
|
"dev": "vite",
|
||||||
"build": "tsc -b && vite build",
|
"build": "npm run format && tsc -b && vite build",
|
||||||
"format": "eslint . && prettier --write .",
|
"format": "eslint . && prettier --write .",
|
||||||
"lint": "eslint .",
|
"lint": "eslint .",
|
||||||
"preview": "vite preview"
|
"preview": "vite preview"
|
||||||
|
@ -21,6 +21,7 @@
|
||||||
"dexie": "^4.0.11",
|
"dexie": "^4.0.11",
|
||||||
"highlight.js": "^11.10.0",
|
"highlight.js": "^11.10.0",
|
||||||
"katex": "^0.16.15",
|
"katex": "^0.16.15",
|
||||||
|
"pdfjs-dist": "^5.2.133",
|
||||||
"postcss": "^8.4.49",
|
"postcss": "^8.4.49",
|
||||||
"react": "^18.3.1",
|
"react": "^18.3.1",
|
||||||
"react-dom": "^18.3.1",
|
"react-dom": "^18.3.1",
|
||||||
|
@ -47,6 +48,7 @@
|
||||||
"eslint": "^9.17.0",
|
"eslint": "^9.17.0",
|
||||||
"eslint-plugin-react-hooks": "^5.0.0",
|
"eslint-plugin-react-hooks": "^5.0.0",
|
||||||
"eslint-plugin-react-refresh": "^0.4.16",
|
"eslint-plugin-react-refresh": "^0.4.16",
|
||||||
|
"fflate": "^0.8.2",
|
||||||
"globals": "^15.14.0",
|
"globals": "^15.14.0",
|
||||||
"prettier": "^3.4.2",
|
"prettier": "^3.4.2",
|
||||||
"sass-embedded": "^1.83.4",
|
"sass-embedded": "^1.83.4",
|
||||||
|
|
|
@ -16,6 +16,8 @@ export const CONFIG_DEFAULT = {
|
||||||
showTokensPerSecond: false,
|
showTokensPerSecond: false,
|
||||||
showThoughtInProgress: false,
|
showThoughtInProgress: false,
|
||||||
excludeThoughtOnReq: true,
|
excludeThoughtOnReq: true,
|
||||||
|
pasteLongTextToFileLen: 2500,
|
||||||
|
pdfAsImage: false,
|
||||||
// make sure these default values are in sync with `common.h`
|
// make sure these default values are in sync with `common.h`
|
||||||
samplers: 'edkypmxt',
|
samplers: 'edkypmxt',
|
||||||
temperature: 0.8,
|
temperature: 0.8,
|
||||||
|
@ -43,6 +45,8 @@ export const CONFIG_DEFAULT = {
|
||||||
export const CONFIG_INFO: Record<string, string> = {
|
export const CONFIG_INFO: Record<string, string> = {
|
||||||
apiKey: 'Set the API Key if you are using --api-key option for the server.',
|
apiKey: 'Set the API Key if you are using --api-key option for the server.',
|
||||||
systemMessage: 'The starting message that defines how model should behave.',
|
systemMessage: 'The starting message that defines how model should behave.',
|
||||||
|
pasteLongTextToFileLen:
|
||||||
|
'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.',
|
||||||
samplers:
|
samplers:
|
||||||
'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature',
|
'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature',
|
||||||
temperature:
|
temperature:
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { useEffect, useMemo, useRef, useState } from 'react';
|
import { ClipboardEvent, useEffect, useMemo, useRef, useState } from 'react';
|
||||||
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
||||||
import ChatMessage from './ChatMessage';
|
import ChatMessage from './ChatMessage';
|
||||||
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
||||||
|
@ -306,6 +306,7 @@ function ChatInput({
|
||||||
onStop: () => void;
|
onStop: () => void;
|
||||||
isGenerating: boolean;
|
isGenerating: boolean;
|
||||||
}) {
|
}) {
|
||||||
|
const { config } = useAppContext();
|
||||||
const [isDrag, setIsDrag] = useState(false);
|
const [isDrag, setIsDrag] = useState(false);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -328,6 +329,38 @@ function ChatInput({
|
||||||
{({ getRootProps, getInputProps }) => (
|
{({ getRootProps, getInputProps }) => (
|
||||||
<div
|
<div
|
||||||
className="flex flex-col rounded-xl border-1 border-base-content/30 p-3 w-full"
|
className="flex flex-col rounded-xl border-1 border-base-content/30 p-3 w-full"
|
||||||
|
// when a file is pasted to the input, we handle it here
|
||||||
|
// if a text is pasted, and if it is long text, we will convert it to a file
|
||||||
|
onPasteCapture={(e: ClipboardEvent<HTMLInputElement>) => {
|
||||||
|
const text = e.clipboardData.getData('text/plain');
|
||||||
|
if (
|
||||||
|
text.length > 0 &&
|
||||||
|
config.pasteLongTextToFileLen > 0 &&
|
||||||
|
text.length > config.pasteLongTextToFileLen
|
||||||
|
) {
|
||||||
|
// if the text is too long, we will convert it to a file
|
||||||
|
extraContext.addItems([
|
||||||
|
{
|
||||||
|
type: 'context',
|
||||||
|
name: 'Pasted Content',
|
||||||
|
content: text,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
e.preventDefault();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if a file is pasted, we will handle it here
|
||||||
|
const files = Array.from(e.clipboardData.items)
|
||||||
|
.filter((item) => item.kind === 'file')
|
||||||
|
.map((item) => item.getAsFile())
|
||||||
|
.filter((file) => file !== null);
|
||||||
|
|
||||||
|
if (files.length > 0) {
|
||||||
|
e.preventDefault();
|
||||||
|
extraContext.onFileAdded(files);
|
||||||
|
}
|
||||||
|
}}
|
||||||
{...getRootProps()}
|
{...getRootProps()}
|
||||||
>
|
>
|
||||||
{!isGenerating && (
|
{!isGenerating && (
|
||||||
|
|
|
@ -100,6 +100,16 @@ const SETTING_SECTIONS: SettingSection[] = [
|
||||||
key,
|
key,
|
||||||
}) as SettingFieldInput
|
}) as SettingFieldInput
|
||||||
),
|
),
|
||||||
|
{
|
||||||
|
type: SettingInputType.SHORT_INPUT,
|
||||||
|
label: 'Paste length to file',
|
||||||
|
key: 'pasteLongTextToFileLen',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: SettingInputType.CHECKBOX,
|
||||||
|
label: 'Parse PDF as image instead of text',
|
||||||
|
key: 'pdfAsImage',
|
||||||
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -452,10 +462,10 @@ function SettingsModalLongInput({
|
||||||
label?: string;
|
label?: string;
|
||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
<label className="form-control mb-2">
|
<label className="form-control">
|
||||||
<div className="label inline">{label || configKey}</div>
|
<div className="label inline text-sm">{label || configKey}</div>
|
||||||
<textarea
|
<textarea
|
||||||
className="textarea textarea-bordered h-24"
|
className="textarea textarea-bordered h-24 mb-2"
|
||||||
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
|
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
|
||||||
value={value}
|
value={value}
|
||||||
onChange={(e) => onChange(e.target.value)}
|
onChange={(e) => onChange(e.target.value)}
|
||||||
|
@ -482,9 +492,7 @@ function SettingsModalShortInput({
|
||||||
<>
|
<>
|
||||||
{/* on mobile, we simply show the help message here */}
|
{/* on mobile, we simply show the help message here */}
|
||||||
{helpMsg && (
|
{helpMsg && (
|
||||||
<div className="block md:hidden mb-1">
|
<div className="block mb-1 opacity-75">
|
||||||
<b>{label || configKey}</b>
|
|
||||||
<br />
|
|
||||||
<p className="text-xs">{helpMsg}</p>
|
<p className="text-xs">{helpMsg}</p>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
@ -493,11 +501,6 @@ function SettingsModalShortInput({
|
||||||
<div tabIndex={0} role="button" className="font-bold hidden md:block">
|
<div tabIndex={0} role="button" className="font-bold hidden md:block">
|
||||||
{label || configKey}
|
{label || configKey}
|
||||||
</div>
|
</div>
|
||||||
{helpMsg && (
|
|
||||||
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
|
|
||||||
{helpMsg}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
<input
|
<input
|
||||||
type="text"
|
type="text"
|
||||||
|
|
|
@ -2,6 +2,17 @@ import { useState } from 'react';
|
||||||
import { MessageExtra } from '../utils/types';
|
import { MessageExtra } from '../utils/types';
|
||||||
import toast from 'react-hot-toast';
|
import toast from 'react-hot-toast';
|
||||||
import { useAppContext } from '../utils/app.context';
|
import { useAppContext } from '../utils/app.context';
|
||||||
|
import * as pdfjs from 'pdfjs-dist';
|
||||||
|
import pdfjsWorkerSrc from 'pdfjs-dist/build/pdf.worker.min.mjs?url';
|
||||||
|
import { TextContent, TextItem } from 'pdfjs-dist/types/src/display/api';
|
||||||
|
|
||||||
|
pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
|
||||||
|
|
||||||
|
// This file handles uploading extra context items (a.k.a files)
|
||||||
|
// It allows processing these kinds of files:
|
||||||
|
// - image files (converted to base64)
|
||||||
|
// - text files (including code files)
|
||||||
|
// - pdf (converted to text)
|
||||||
|
|
||||||
// Interface describing the API returned by the hook
|
// Interface describing the API returned by the hook
|
||||||
export interface ChatExtraContextApi {
|
export interface ChatExtraContextApi {
|
||||||
|
@ -13,7 +24,7 @@ export interface ChatExtraContextApi {
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useChatExtraContext(): ChatExtraContextApi {
|
export function useChatExtraContext(): ChatExtraContextApi {
|
||||||
const { serverProps } = useAppContext();
|
const { serverProps, config } = useAppContext();
|
||||||
const [items, setItems] = useState<MessageExtra[]>([]);
|
const [items, setItems] = useState<MessageExtra[]>([]);
|
||||||
|
|
||||||
const addItems = (newItems: MessageExtra[]) => {
|
const addItems = (newItems: MessageExtra[]) => {
|
||||||
|
@ -28,6 +39,8 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||||
setItems([]);
|
setItems([]);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const isSupportVision = serverProps?.modalities?.vision;
|
||||||
|
|
||||||
const onFileAdded = (files: File[]) => {
|
const onFileAdded = (files: File[]) => {
|
||||||
for (const file of files) {
|
for (const file of files) {
|
||||||
const mimeType = file.type;
|
const mimeType = file.type;
|
||||||
|
@ -38,7 +51,7 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mimeType.startsWith('image/')) {
|
if (mimeType.startsWith('image/')) {
|
||||||
if (!serverProps?.modalities?.vision) {
|
if (!isSupportVision) {
|
||||||
toast.error('Multimodal is not supported by this server or model.');
|
toast.error('Multimodal is not supported by this server or model.');
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -69,7 +82,43 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||||
toast.error('Video and audio files are not supported yet.');
|
toast.error('Video and audio files are not supported yet.');
|
||||||
break;
|
break;
|
||||||
} else if (mimeType.startsWith('application/pdf')) {
|
} else if (mimeType.startsWith('application/pdf')) {
|
||||||
toast.error('PDF files are not supported yet.');
|
if (config.pdfAsImage && !isSupportVision) {
|
||||||
|
toast(
|
||||||
|
'Multimodal is not supported, PDF will be converted to text instead of image.'
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const promise =
|
||||||
|
config.pdfAsImage && isSupportVision
|
||||||
|
? convertPDFToImage(file).then((base64Urls) => {
|
||||||
|
addItems(
|
||||||
|
base64Urls.map((base64Url) => ({
|
||||||
|
type: 'imageFile',
|
||||||
|
name: file.name,
|
||||||
|
base64Url,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
})
|
||||||
|
: convertPDFToText(file).then((content) => {
|
||||||
|
if (isSupportVision) {
|
||||||
|
toast.success(
|
||||||
|
'PDF file converted to text. You can also convert it to image, see in Settings.'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
addItems([
|
||||||
|
{
|
||||||
|
type: 'textFile',
|
||||||
|
name: file.name,
|
||||||
|
content,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
promise.catch((error) => {
|
||||||
|
console.error(error);
|
||||||
|
toast.error('Failed to parse PDF file.');
|
||||||
|
});
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
// Because there can be many text file types (like code file), we will not check the mime type
|
// Because there can be many text file types (like code file), we will not check the mime type
|
||||||
|
@ -105,11 +154,69 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = (event) => {
|
||||||
|
if (event.target?.result) {
|
||||||
|
resolve(event.target.result as ArrayBuffer);
|
||||||
|
} else {
|
||||||
|
reject(new Error('Failed to read file.'));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
reader.readAsArrayBuffer(file);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async function convertPDFToText(file: File): Promise<string> {
|
||||||
|
const buffer = await getFileAsBuffer(file);
|
||||||
|
const pdf = await pdfjs.getDocument(buffer).promise;
|
||||||
|
const numPages = pdf.numPages;
|
||||||
|
const textContentPromises: Promise<TextContent>[] = [];
|
||||||
|
for (let i = 1; i <= numPages; i++) {
|
||||||
|
textContentPromises.push(
|
||||||
|
pdf.getPage(i).then((page) => page.getTextContent())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const textContents = await Promise.all(textContentPromises);
|
||||||
|
const textItems = textContents.flatMap((textContent: TextContent) =>
|
||||||
|
textContent.items.map((item) => (item as TextItem).str ?? '')
|
||||||
|
);
|
||||||
|
return textItems.join('\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns list of base64 images
|
||||||
|
async function convertPDFToImage(file: File): Promise<string[]> {
|
||||||
|
const buffer = await getFileAsBuffer(file);
|
||||||
|
const doc = await pdfjs.getDocument(buffer).promise;
|
||||||
|
const pages: Promise<string>[] = [];
|
||||||
|
|
||||||
|
for (let i = 1; i <= doc.numPages; i++) {
|
||||||
|
const page = await doc.getPage(i);
|
||||||
|
const viewport = page.getViewport({ scale: 1.5 });
|
||||||
|
const canvas = document.createElement('canvas');
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
canvas.width = viewport.width;
|
||||||
|
canvas.height = viewport.height;
|
||||||
|
if (!ctx) {
|
||||||
|
throw new Error('Failed to get 2D context from canvas');
|
||||||
|
}
|
||||||
|
const task = page.render({ canvasContext: ctx, viewport: viewport });
|
||||||
|
pages.push(
|
||||||
|
task.promise.then(() => {
|
||||||
|
return canvas.toDataURL();
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return await Promise.all(pages);
|
||||||
|
}
|
||||||
|
|
||||||
// WARN: vibe code below
|
// WARN: vibe code below
|
||||||
// This code is a heuristic to determine if a string is likely not binary.
|
// This code is a heuristic to determine if a string is likely not binary.
|
||||||
// It is necessary because input file can have various mime types which we don't have time to investigate.
|
// It is necessary because input file can have various mime types which we don't have time to investigate.
|
||||||
// For example, a python file can be text/plain, application/x-python, etc.
|
// For example, a python file can be text/plain, application/x-python, etc.
|
||||||
export function isLikelyNotBinary(str: string): boolean {
|
function isLikelyNotBinary(str: string): boolean {
|
||||||
const options = {
|
const options = {
|
||||||
prefixLength: 1024 * 10, // Check the first 10KB of the string
|
prefixLength: 1024 * 10, // Check the first 10KB of the string
|
||||||
suspiciousCharThresholdRatio: 0.15, // Allow up to 15% suspicious chars
|
suspiciousCharThresholdRatio: 0.15, // Allow up to 15% suspicious chars
|
||||||
|
|
|
@ -3,11 +3,11 @@ import react from '@vitejs/plugin-react';
|
||||||
import { viteSingleFile } from 'vite-plugin-singlefile';
|
import { viteSingleFile } from 'vite-plugin-singlefile';
|
||||||
import path from 'node:path';
|
import path from 'node:path';
|
||||||
import fs from 'node:fs';
|
import fs from 'node:fs';
|
||||||
import zlib from 'node:zlib';
|
import * as fflate from 'fflate';
|
||||||
|
|
||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
const MAX_BUNDLE_SIZE = 1.5 * 1024 * 1024; // only increase when absolutely necessary
|
const MAX_BUNDLE_SIZE = 2 * 1024 * 1024; // only increase when absolutely necessary
|
||||||
|
|
||||||
const GUIDE_FOR_FRONTEND = `
|
const GUIDE_FOR_FRONTEND = `
|
||||||
<!--
|
<!--
|
||||||
|
@ -33,9 +33,10 @@ const BUILD_PLUGINS = [
|
||||||
},
|
},
|
||||||
writeBundle() {
|
writeBundle() {
|
||||||
const outputIndexHtml = path.join(config.build.outDir, 'index.html');
|
const outputIndexHtml = path.join(config.build.outDir, 'index.html');
|
||||||
const content =
|
let content =
|
||||||
GUIDE_FOR_FRONTEND + '\n' + fs.readFileSync(outputIndexHtml, 'utf-8');
|
GUIDE_FOR_FRONTEND + '\n' + fs.readFileSync(outputIndexHtml, 'utf-8');
|
||||||
const compressed = zlib.gzipSync(Buffer.from(content, 'utf-8'), {
|
content = content.replace(/\r/g, ''); // remove windows-style line endings
|
||||||
|
const compressed = fflate.gzipSync(Buffer.from(content, 'utf-8'), {
|
||||||
level: 9,
|
level: 9,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue